From 389c9619afe8b6d129a12d137bf332491503bb83 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 00:31:41 +0300 Subject: [PATCH 01/28] Refactor setup key handling to use store methods Signed-off-by: bcmmbaga --- management/server/setupkey.go | 179 +++++++++++++++++----------- management/server/sql_store.go | 83 ++++++++----- management/server/sql_store_test.go | 4 +- management/server/status/error.go | 15 ++- management/server/store.go | 7 +- 5 files changed, 178 insertions(+), 110 deletions(-) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 43b6e02c936..f54eafdc1fd 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -4,7 +4,6 @@ import ( "context" "crypto/sha256" b64 "encoding/base64" - "fmt" "hash/fnv" "strconv" "strings" @@ -12,9 +11,8 @@ import ( "unicode/utf8" "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -226,34 +224,49 @@ func Hash(s string) uint32 { // and adds it to the specified account. A list of autoGroups IDs can be empty. func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil { - return nil, err + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) - account.SetupKeys[setupKey.Key] = setupKey - err = am.Store.SaveAccount(ctx, account) + var accountGroups []*nbgroup.Group + var setupKey *SetupKey + var plainKey string + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + accountGroups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + if err = validateSetupKeyAutoGroups(accountGroups, autoGroups); err != nil { + return err + } + + setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) + setupKey.AccountID = accountID + + return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey) + }) if err != nil { - return nil, status.Errorf(status.Internal, "failed adding account key") + return nil, err } am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) + groupMap := make(map[string]*nbgroup.Group, len(accountGroups)) + for _, g := range accountGroups { + groupMap[g.ID] = g + } for _, g := range setupKey.AutoGroups { - group := account.GetGroup(g) - if group != nil { + group, ok := groupMap[g] + if ok { am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name}) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) } } @@ -268,43 +281,48 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s // (e.g. the key itself, creation date, ID, etc). // These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - if keyToSave == nil { return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + var accountGroups []*nbgroup.Group var oldKey *SetupKey - for _, key := range account.SetupKeys { - if key.Id == keyToSave.Id { - oldKey = key.Copy() - break + var newKey *SetupKey + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + accountGroups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + return err } - } - if oldKey == nil { - return nil, status.Errorf(status.NotFound, "setup key not found") - } - if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil { - return nil, err - } + if err = validateSetupKeyAutoGroups(accountGroups, keyToSave.AutoGroups); err != nil { + return err + } - // only auto groups, revoked status, and name can be updated for now - newKey := oldKey.Copy() - newKey.Name = keyToSave.Name - newKey.AutoGroups = keyToSave.AutoGroups - newKey.Revoked = keyToSave.Revoked - newKey.UpdatedAt = time.Now().UTC() + oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id) + if err != nil { + return err + } - account.SetupKeys[newKey.Key] = newKey + // only auto groups, revoked status, and name can be updated for now + newKey = oldKey.Copy() + newKey.Name = keyToSave.Name + newKey.AutoGroups = keyToSave.AutoGroups + newKey.Revoked = keyToSave.Revoked + newKey.UpdatedAt = time.Now().UTC() - if err = am.Store.SaveAccount(ctx, account); err != nil { + return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey) + }) + if err != nil { return nil, err } @@ -315,24 +333,25 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str defer func() { addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) + + groupMap := make(map[string]*nbgroup.Group, len(accountGroups)) + for _, g := range accountGroups { + groupMap[g.ID] = g + } + for _, g := range removedGroups { - group := account.GetGroup(g) - if group != nil { + group, ok := groupMap[g] + if ok { am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) } - } for _, g := range addedGroups { - group := account.GetGroup(g) - if group != nil { + group, ok := groupMap[g] + if ok { am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) } } }() @@ -347,16 +366,15 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.NewUnauthorizedToViewSetupKeysError() + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, err + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } - return setupKeys, nil + return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. @@ -366,8 +384,12 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.NewUnauthorizedToViewSetupKeysError() + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) @@ -387,21 +409,29 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return fmt.Errorf("failed to get user: %w", err) + return err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return status.NewUnauthorizedToViewSetupKeysError() + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) - if err != nil { - return fmt.Errorf("failed to get setup key: %w", err) + if user.IsRegularUser() { + return status.NewAdminPermissionError() } - err = am.Store.DeleteSetupKey(ctx, accountID, keyID) + var deletedSetupKey *SetupKey + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) + if err != nil { + return err + } + + return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID) + }) if err != nil { - return fmt.Errorf("failed to delete setup key: %w", err) + return err } am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta()) @@ -409,15 +439,22 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return nil } -func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { - for _, group := range autoGroups { - g, ok := account.Groups[group] - if !ok { - return status.Errorf(status.NotFound, "group %s doesn't exist", group) +func validateSetupKeyAutoGroups(groups []*nbgroup.Group, autoGroups []string) error { + groupMap := make(map[string]*nbgroup.Group, len(groups)) + for _, g := range groups { + groupMap[g.ID] = g + } + + for _, groupID := range autoGroups { + g, exists := groupMap[groupID] + if !exists { + return status.Errorf(status.NotFound, "group %s doesn't exist", groupID) } + if g.Name == "All" { - return status.Errorf(status.InvalidArgument, "can't add All group to the setup key") + return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") } } + return nil } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 646184578eb..a11370e4f9d 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -633,11 +633,11 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*Us return users, nil } -func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { +func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) { startTime := time.Now() var groups []*nbgroup.Group - result := s.db.Find(&groups, accountIDCondition, accountID) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -645,8 +645,8 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*n if errors.Is(result.Error, context.Canceled) { return nil, status.NewStoreContextCanceledError(time.Since(startTime)) } - log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting groups from store") + log.WithContext(ctx).Errorf("failed to get account groups from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get account groups from the store") } return groups, nil @@ -1404,12 +1404,59 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt // GetAccountSetupKeys retrieves setup keys for an account. func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { - return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID) + var setupKeys []*SetupKey + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&setupKeys, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get setup keys from store") + } + + return setupKeys, nil } // GetSetupKeyByID retrieves a setup key by its ID and account ID. -func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) { - return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID) +func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) { + var setupKey *SetupKey + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "setup key not found") + } + log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get setup key from store") + } + + return setupKey, nil +} + +// SaveSetupKey saves a setup key to the database. +func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error { + result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error) + return status.Errorf(status.Internal, "failed to save setup key to store") + } + + return nil +} + +// DeleteSetupKey deletes a setup key from the database. +func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete setup key from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "setup key not found") + } + + return nil } // GetAccountNameServerGroups retrieves name server groups for an account. @@ -1422,10 +1469,6 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID) } -func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error { - return deleteRecordByID[SetupKey](s.db.WithContext(ctx), LockingStrengthUpdate, keyID, accountID) -} - // getRecords retrieves records from the database based on the account ID. func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { var record []T @@ -1458,21 +1501,3 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a } return &record, nil } - -// deleteRecordByID deletes a record by its ID and account ID from the database. -func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error { - var record T - result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(record, accountAndIDQueryCondition, accountID, recordID) - if err := result.Error; err != nil { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] - - return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err) - } - - if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "record not found") - } - - return nil -} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index b371e231319..3f3b2a453d4 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1274,7 +1274,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" - err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID) + err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, setupKeyID) require.NoError(t, err) _, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID) @@ -1290,6 +1290,6 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" nonExistingKeyID := "non-existing-key-id" - err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID) + err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID) require.Error(t, err) } diff --git a/management/server/status/error.go b/management/server/status/error.go index a145edf8002..5a75c94b1c1 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -111,11 +111,21 @@ func NewGetAccountFromStoreError(err error) error { return Errorf(Internal, "issue getting account from store: %s", err) } +// NewUserNotPartOfAccountError creates a new Error with PermissionDenied type for a user not being part of an account +func NewUserNotPartOfAccountError() error { + return Errorf(PermissionDenied, "user is not part of this account") +} + // NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store func NewGetUserFromStoreError() error { return Errorf(Internal, "issue getting user from store") } +// NewAdminPermissionError creates a new Error with PermissionDenied type for actions requiring admin role. +func NewAdminPermissionError() error { + return Errorf(PermissionDenied, "admin role required to perform this action") +} + // NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context func NewStoreContextCanceledError(duration time.Duration) error { return Errorf(Internal, "store access: context canceled after %v", duration) @@ -125,8 +135,3 @@ func NewStoreContextCanceledError(duration time.Duration) error { func NewInvalidKeyIDError() error { return Errorf(InvalidArgument, "invalid key ID") } - -// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key -func NewUnauthorizedToViewSetupKeysError() error { - return Errorf(Unauthorized, "only users with admin power can view setup keys") -} diff --git a/management/server/store.go b/management/server/store.go index 087c9884763..73c9ef6a692 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -70,7 +70,7 @@ type Store interface { DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error - GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) + GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error @@ -96,7 +96,9 @@ type Store interface { GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) - GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) + GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) + SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error + DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) @@ -124,7 +126,6 @@ type Store interface { // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine ExecuteInTransaction(ctx context.Context, f func(store Store) error) error - DeleteSetupKey(ctx context.Context, accountID, keyID string) error } type StoreEngine string From 78044c226d9240edcdd5bb180aaab1da86f442e4 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 00:32:14 +0300 Subject: [PATCH 02/28] add lock to get account groups Signed-off-by: bcmmbaga --- management/server/account.go | 4 ++-- management/server/account_test.go | 2 +- management/server/group.go | 2 +- management/server/peer.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index aa7609388c0..583853f2504 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2029,7 +2029,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error getting user: %w", err) } - groups, err := transaction.GetAccountGroups(ctx, accountID) + groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } @@ -2059,7 +2059,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - groups, err = transaction.GetAccountGroups(ctx, accountID) + groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 6a2d85fe8f7..fdf004a3b8a 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2773,7 +2773,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID") + groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID") assert.NoError(t, err) assert.Len(t, groups, 3, "new group3 should be added") diff --git a/management/server/group.go b/management/server/group.go index bdb569e377f..b2ec88cc0d2 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -59,7 +59,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us return nil, err } - return am.Store.GetAccountGroups(ctx, accountID) + return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers diff --git a/management/server/peer.go b/management/server/peer.go index 9c5ab571bab..8ced2a1deb0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -765,7 +765,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } } - groups, err := am.Store.GetAccountGroups(ctx, accountID) + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } From 1a5f3c653c4b78a5c52bca9bba74c966fbd7495c Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 00:37:47 +0300 Subject: [PATCH 03/28] add check for regular user Signed-off-by: bcmmbaga --- management/server/user.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/management/server/user.go b/management/server/user.go index 9fdd3a6eeea..1368b76b121 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -103,6 +103,11 @@ func (u *User) IsAdminOrServiceUser() bool { return u.HasAdminPower() || u.IsServiceUser } +// IsRegularUser checks if the user is a regular user. +func (u *User) IsRegularUser() bool { + return !u.HasAdminPower() && !u.IsServiceUser +} + // ToUserInfo converts a User object to a UserInfo object. func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { autoGroups := u.AutoGroups From 931521d505b012f45dcf5bcb5de0ee07f0c5b876 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 00:59:37 +0300 Subject: [PATCH 04/28] get only required groups for auto-group validation Signed-off-by: bcmmbaga --- management/server/group/group.go | 5 ++++ management/server/setupkey.go | 46 +++++++++++++------------------- 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/management/server/group/group.go b/management/server/group/group.go index d293e1afc6f..e98e5ecc4b5 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -49,3 +49,8 @@ func (g *Group) Copy() *Group { func (g *Group) HasPeers() bool { return len(g.Peers) > 0 } + +// IsGroupAll checks if the group is a default "All" group. +func (g *Group) IsGroupAll() bool { + return g.Name == "All" +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index f54eafdc1fd..da248be25d6 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -233,20 +233,16 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, status.NewUserNotPartOfAccountError() } - var accountGroups []*nbgroup.Group + var groups []*nbgroup.Group var setupKey *SetupKey var plainKey string err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - accountGroups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups) if err != nil { return err } - if err = validateSetupKeyAutoGroups(accountGroups, autoGroups); err != nil { - return err - } - setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) setupKey.AccountID = accountID @@ -257,8 +253,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s } am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) - groupMap := make(map[string]*nbgroup.Group, len(accountGroups)) - for _, g := range accountGroups { + groupMap := make(map[string]*nbgroup.Group, len(groups)) + for _, g := range groups { groupMap[g.ID] = g } @@ -294,20 +290,16 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.NewUserNotPartOfAccountError() } - var accountGroups []*nbgroup.Group + var groups []*nbgroup.Group var oldKey *SetupKey var newKey *SetupKey err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - accountGroups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups) if err != nil { return err } - if err = validateSetupKeyAutoGroups(accountGroups, keyToSave.AutoGroups); err != nil { - return err - } - oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id) if err != nil { return err @@ -334,8 +326,8 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) - groupMap := make(map[string]*nbgroup.Group, len(accountGroups)) - for _, g := range accountGroups { + groupMap := make(map[string]*nbgroup.Group, len(groups)) + for _, g := range groups { groupMap[g.ID] = g } @@ -439,22 +431,20 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return nil } -func validateSetupKeyAutoGroups(groups []*nbgroup.Group, autoGroups []string) error { - groupMap := make(map[string]*nbgroup.Group, len(groups)) - for _, g := range groups { - groupMap[g.ID] = g - } +func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) ([]*nbgroup.Group, error) { + autoGroups := make([]*nbgroup.Group, 0, len(autoGroupIDs)) - for _, groupID := range autoGroups { - g, exists := groupMap[groupID] - if !exists { - return status.Errorf(status.NotFound, "group %s doesn't exist", groupID) + for _, groupID := range autoGroupIDs { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + if err != nil { + return nil, err } - if g.Name == "All" { - return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") + if group.IsGroupAll() { + return nil, status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") } + autoGroups = append(autoGroups, group) } - return nil + return autoGroups, nil } From f8b5eedd382d8a218517cf7c7b552f3a0dd8ee3d Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 10:14:13 +0300 Subject: [PATCH 05/28] add account lock and return auto groups map on validation Signed-off-by: bcmmbaga --- management/server/setupkey.go | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index da248be25d6..65d7796f1a0 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -224,6 +224,9 @@ func Hash(s string) uint32 { // and adds it to the specified account. A list of autoGroups IDs can be empty. func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err @@ -233,7 +236,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, status.NewUserNotPartOfAccountError() } - var groups []*nbgroup.Group + var groups map[string]*nbgroup.Group var setupKey *SetupKey var plainKey string @@ -253,13 +256,9 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s } am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) - groupMap := make(map[string]*nbgroup.Group, len(groups)) - for _, g := range groups { - groupMap[g.ID] = g - } for _, g := range setupKey.AutoGroups { - group, ok := groupMap[g] + group, ok := groups[g] if ok { am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name}) @@ -281,6 +280,9 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err @@ -290,7 +292,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.NewUserNotPartOfAccountError() } - var groups []*nbgroup.Group + var groups map[string]*nbgroup.Group var oldKey *SetupKey var newKey *SetupKey @@ -326,13 +328,8 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) - groupMap := make(map[string]*nbgroup.Group, len(groups)) - for _, g := range groups { - groupMap[g.ID] = g - } - for _, g := range removedGroups { - group, ok := groupMap[g] + group, ok := groups[g] if ok { am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) @@ -340,7 +337,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str } for _, g := range addedGroups { - group, ok := groupMap[g] + group, ok := groups[g] if ok { am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) @@ -431,8 +428,8 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return nil } -func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) ([]*nbgroup.Group, error) { - autoGroups := make([]*nbgroup.Group, 0, len(autoGroupIDs)) +func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) (map[string]*nbgroup.Group, error) { + autoGroups := map[string]*nbgroup.Group{} for _, groupID := range autoGroupIDs { group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) @@ -443,7 +440,7 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI if group.IsGroupAll() { return nil, status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") } - autoGroups = append(autoGroups, group) + autoGroups[group.ID] = group } return autoGroups, nil From 106fc759365d535db529d93c9c1ad0324b2ccff6 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 18:38:32 +0300 Subject: [PATCH 06/28] refactor account peers update Signed-off-by: bcmmbaga --- management/server/account.go | 25 ++++++++++++------------- management/server/dns.go | 2 +- management/server/nameserver.go | 6 +++--- management/server/peer.go | 23 ++++++++++++++--------- management/server/peer_test.go | 2 +- management/server/policy.go | 4 ++-- management/server/posture_checks.go | 2 +- management/server/route.go | 6 +++--- management/server/user.go | 8 ++++---- 9 files changed, 41 insertions(+), 37 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 583853f2504..2b18c344101 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -110,7 +110,6 @@ type AccountManager interface { SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error - ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error) GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) @@ -1435,7 +1434,7 @@ func isNil(i idp.Manager) bool { // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { - accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) + accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) if err != nil { return err } @@ -2083,7 +2082,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error saving groups: %w", err) } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } } @@ -2127,14 +2126,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.GroupsPropagationEnabled { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + removedGroupAffectsPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, removeOldGroups) if err != nil { - return fmt.Errorf("error getting account: %w", err) + return err } - if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) { + newGroupsAffectsPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, addNewGroups) + if err != nil { + return err + } + + if removedGroupAffectsPeers || newGroupsAffectsPeers { log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } } @@ -2398,12 +2402,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) - updatedAccount, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) - return - } - am.updateAccountPeers(ctx, updatedAccount) + am.updateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { diff --git a/management/server/dns.go b/management/server/dns.go index 256b8b12512..4551be5ab92 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -146,7 +146,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 5ebd263dcc2..957008714e5 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -71,7 +71,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco } if anyGroupHasPeers(account, newNSGroup.Groups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) @@ -106,7 +106,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun } if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) @@ -136,7 +136,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco } if anyGroupHasPeers(account, nsGroup.Groups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) diff --git a/management/server/peer.go b/management/server/peer.go index 8ced2a1deb0..994cc02879c 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -131,7 +131,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK if expired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil @@ -267,7 +267,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } if peerLabelUpdated || requiresPeerUpdates { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return peer, nil @@ -344,7 +344,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -551,7 +551,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed to add peer to account: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, accountID) + err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -597,7 +597,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s groupsToAdd = append(groupsToAdd, allGroup.ID) if areGroupChangesAffectPeers(account, groupsToAdd) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } approvedPeersMap, err := am.GetValidatedPeers(account) @@ -661,7 +661,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if sync.UpdateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } } @@ -680,7 +680,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if isStatusChanged { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } validPeersMap, err := am.GetValidatedPeers(account) @@ -811,7 +811,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } if updateRemotePeers || isStatusChanged { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) @@ -974,7 +974,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { +func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) { start := time.Now() defer func() { if am.metrics != nil { @@ -982,6 +982,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account } }() + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) + return + } peers := account.GetPeers() approvedPeersMap, err := am.GetValidatedPeers(account) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 78885ea1b72..4e2dcb2c313 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -877,7 +877,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { start := time.Now() for i := 0; i < b.N; i++ { - manager.updateAccountPeers(ctx, account) + manager.updateAccountPeers(ctx, account.Id) } duration := time.Since(start) diff --git a/management/server/policy.go b/management/server/policy.go index 43a925f8850..8a5733f011c 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -377,7 +377,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -406,7 +406,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) if anyGroupHasPeers(account, policy.ruleGroups()) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 2dccd8f590c..096cff3f5c9 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -69,7 +69,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/route.go b/management/server/route.go index 1cf00b37c46..dcf2cb0d32c 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -238,7 +238,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri } if isRouteChangeAffectPeers(account, &newRoute) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) @@ -324,7 +324,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) @@ -356,7 +356,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) if isRouteChangeAffectPeers(account, routy) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/user.go b/management/server/user.go index 1368b76b121..38b820cb41b 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -492,7 +492,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil @@ -833,7 +833,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } for _, storeEvent := range eventsToStore { @@ -1124,7 +1124,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil } @@ -1232,7 +1232,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } for targetUserID, meta := range deletedUsersMeta { From 0a70e4c5d45292223c78427984fb470aaf0a9a40 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 18:39:36 +0300 Subject: [PATCH 07/28] Refactor groups to use store methods Signed-off-by: bcmmbaga --- management/server/group.go | 390 ++++++++++++------ management/server/integrated_validator.go | 27 +- management/server/mock_server/account_mock.go | 9 - management/server/sql_store.go | 81 +++- management/server/store.go | 7 +- 5 files changed, 355 insertions(+), 159 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index b2ec88cc0d2..da4c0fb9415 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -37,8 +37,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco return err } - if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "groups are blocked for users") + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() && settings.RegularUsersViewBlocked { + return status.NewAdminPermissionError() } return nil @@ -49,8 +53,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - - return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) } // GetAllGroups returns all groups in an account @@ -58,13 +61,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) + return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName) } // SaveGroup object of the peers @@ -78,12 +80,19 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - var eventsToStore []func() + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + var ( + eventsToStore []func() + groupsToSave []*nbgroup.Group + ) for _, newGroup := range newGroups { if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { @@ -91,7 +100,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := account.FindGroupByName(newGroup.Name) + existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) if err != nil { s, ok := status.FromError(err) if !ok || s.ErrorType != status.NotFound { @@ -109,15 +118,15 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } for _, peerID := range newGroup.Peers { - if account.Peers[peerID] == nil { + if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil { return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) } } - oldGroup := account.Groups[newGroup.ID] - account.Groups[newGroup.ID] = newGroup + newGroup.AccountID = accountID + groupsToSave = append(groupsToSave, newGroup) - events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account) + events := am.prepareGroupEvents(ctx, userID, accountID, newGroup) eventsToStore = append(eventsToStore, events...) } @@ -126,30 +135,45 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user newGroupIDs = append(newGroupIDs, newGroup.ID) } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, newGroupIDs) + if err != nil { return err } - if areGroupChangesAffectPeers(account, newGroupIDs) { - am.updateAccountPeers(ctx, account) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil { + return fmt.Errorf("failed to save groups: %w", err) + } + return nil + }) + if err != nil { + return err } for _, storeEvent := range eventsToStore { storeEvent() } + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + return nil } // prepareGroupEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup *nbgroup.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) removedPeers := make([]string, 0) - if oldGroup != nil { + oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) + if err == nil && oldGroup != nil { addedPeers = difference(newGroup.Peers, oldGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers) } else { @@ -159,12 +183,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID }) } - for _, p := range addedPeers { - peer := account.Peers[p] - if peer == nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) + for _, peerID := range addedPeers { + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) continue } + peerCopy := peer // copy to avoid closure issues eventsToStore = append(eventsToStore, func() { am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, @@ -175,12 +200,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID }) } - for _, p := range removedPeers { - peer := account.Peers[p] - if peer == nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) + for _, peerID := range removedPeers { + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) continue } + peerCopy := peer // copy to avoid closure issues eventsToStore = append(eventsToStore, func() { am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, @@ -210,119 +236,108 @@ func difference(a, b []string) []string { } // DeleteGroup object of the peers. -func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountId) +func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return nil + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - allGroup, err := account.GetGroupAll() + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { return err } - if allGroup.ID == groupID { + if group.Name == "All" { return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") } - if err = validateDeleteGroup(account, group, userId); err != nil { + if err = am.validateDeleteGroup(ctx, group, userID); err != nil { return err } - delete(account.Groups, groupID) - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID); err != nil { + return fmt.Errorf("failed to delete group: %w", err) + } + return nil + }) + if err != nil { return err } - am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta()) + am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta()) return nil } // DeleteGroups deletes groups from an account. -// Note: This function does not acquire the global lock. -// It is the caller's responsibility to ensure proper locking is in place before invoking this method. -// -// If an error occurs while deleting a group, the function skips it and continues deleting other groups. -// Errors are collected and returned at the end. -func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error { - account, err := am.Store.GetAccount(ctx, accountId) +func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - var allErrors error + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + var ( + allErrors error + groupIDsToDelete []string + deletedGroups []*nbgroup.Group + ) - deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs)) for _, groupID := range groupIDs { - group, ok := account.Groups[groupID] - if !ok { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + if err != nil { continue } - if err := validateDeleteGroup(account, group, userId); err != nil { + if err := am.validateDeleteGroup(ctx, group, userID); err != nil { allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) continue } - delete(account.Groups, groupID) + groupIDsToDelete = append(groupIDsToDelete, groupID) deletedGroups = append(deletedGroups, group) } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - for _, g := range deletedGroups { - am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) - } - - return allErrors -} - -// ListGroups objects of the peers -func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } - account, err := am.Store.GetAccount(ctx, accountID) + if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil { + return fmt.Errorf("failed to delete group: %w", err) + } + return nil + }) if err != nil { - return nil, err + return err } - groups := make([]*nbgroup.Group, 0, len(account.Groups)) - for _, item := range account.Groups { - groups = append(groups, item) + for _, group := range deletedGroups { + am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta()) } - return groups, nil + return allErrors } // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return status.Errorf(status.NotFound, "group with ID %s not found", groupID) - } - add := true for _, itemID := range group.Peers { if itemID == peerID { @@ -334,13 +349,27 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr group.Peers = append(group.Peers, peerID) } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID}) + if err != nil { + return err + } + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil { + return fmt.Errorf("failed to save group: %w", err) + } + return nil + }) + if err != nil { return err } - if areGroupChangesAffectPeers(account, []string{group.ID}) { - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) } return nil @@ -348,41 +377,55 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return status.Errorf(status.NotFound, "group with ID %s not found", groupID) - } - - account.Network.IncSerial() + updated := false for i, itemID := range group.Peers { if itemID == peerID { group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) - if err := am.Store.SaveAccount(ctx, account); err != nil { - return err - } + updated = true + break + } + } + + if !updated { + return nil + } + + updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID}) + if err != nil { + return err + } + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil { + return fmt.Errorf("failed to save group: %w", err) } + return nil + }) + if err != nil { + return err } - if areGroupChangesAffectPeers(account, []string{group.ID}) { - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) } return nil } -func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { +func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group *nbgroup.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user if group.Issued == nbgroup.GroupIssuedIntegration { - executingUser := account.Users[userID] - if executingUser == nil { + executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { return status.Errorf(status.NotFound, "user not found") } if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { @@ -390,32 +433,42 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) } } - if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { + if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } - if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { + if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"name server groups", linkedDns.Name} } - if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { + if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"policy", linkedPolicy.Name} } - if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { + if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"setup key", linkedSetupKey.Name} } - if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { + if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"user", linkedUser.Id} } - if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { + dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) + if err != nil { + return err + } + + if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) { return &GroupLinkError{"disabled DNS management groups", group.Name} } - if account.Settings.Extra != nil { - if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) + if err != nil { + return err + } + + if settings.Extra != nil { + if slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { return &GroupLinkError{"integrated validator", group.Name} } } @@ -424,17 +477,30 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) } // isGroupLinkedToRoute checks if a group is linked to any route in the account. -func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { +func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accountID string, groupID string) (bool, *route.Route) { + routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) + return false, nil + } + for _, r := range routes { if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { return true, r } } + return false, nil } // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { +func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, accountID string, groupID string) (bool, *Policy) { + policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) + return false, nil + } + for _, policy := range policies { for _, rule := range policy.Rules { if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { @@ -446,7 +512,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { } // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { +func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) + return false, nil + } + for _, dns := range nameServerGroups { for _, g := range dns.Groups { if g == groupID { @@ -454,11 +526,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou } } } + return false, nil } // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) { +func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, accountID string, groupID string) (bool, *SetupKey) { + setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) + return false, nil + } + for _, setupKey := range setupKeys { if slices.Contains(setupKey.AutoGroups, groupID) { return true, setupKey @@ -468,7 +547,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo } // isGroupLinkedToUser checks if a group is linked to any user in the account. -func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { +func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accountID string, groupID string) (bool, *User) { + users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) + return false, nil + } + for _, user := range users { if slices.Contains(user.AutoGroups, groupID) { return true, user @@ -477,6 +562,69 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { return false, nil } +// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. +func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) { + if len(groupIDs) == 0 { + return false, nil + } + + dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return false, err + } + + for _, groupID := range groupIDs { + if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) { + return true, nil + } + if linked, _ := am.isGroupLinkedToDns(ctx, accountID, groupID); linked { + return true, nil + } + if linked, _ := am.isGroupLinkedToPolicy(ctx, accountID, groupID); linked { + return true, nil + } + if linked, _ := am.isGroupLinkedToRoute(ctx, accountID, groupID); linked { + return true, nil + } + } + + return false, nil +} + +// isGroupLinkedToRoute checks if a group is linked to any route in the account. +func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { + for _, r := range routes { + if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { + return true, r + } + } + return false, nil +} + +// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. +func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { + for _, policy := range policies { + for _, rule := range policy.Rules { + if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { + return true, policy + } + } + } + return false, nil +} + +// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. +func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { + for _, dns := range nameServerGroups { + for _, g := range dns.Groups { + if g == groupID { + return true, dns + } + } + } + return false, nil +} + // anyGroupHasPeers checks if any of the given groups in the account have peers. func anyGroupHasPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 99e6b204c2b..0c70b702a01 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -52,25 +52,22 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con return am.Store.SaveAccount(ctx, a) } -func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { - if len(groups) == 0 { +func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) { + if len(groupIDs) == 0 { return true, nil } - accountsGroups, err := am.ListGroups(ctx, accountId) - if err != nil { - return false, err - } - for _, group := range groups { - var found bool - for _, accountGroup := range accountsGroups { - if accountGroup.ID == group { - found = true - break + + err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for _, groupID := range groupIDs { + _, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + if err != nil { + return err } } - if !found { - return false, nil - } + return nil + }) + if err != nil { + return false, err } return true, nil diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d7139bb2a5f..aa6a47b152e 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -45,7 +45,6 @@ type MockAccountManager struct { SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error - ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error @@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented") } -// ListGroups mock implementation of ListGroups from server.AccountManager interface -func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) { - if am.ListGroupsFunc != nil { - return am.ListGroupsFunc(ctx, accountID) - } - return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented") -} - // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { if am.GroupAddPeerFunc != nil { diff --git a/management/server/sql_store.go b/management/server/sql_store.go index a11370e4f9d..506142453e6 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -614,11 +614,11 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } -func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { +func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { startTime := time.Now() var users []*User - result := s.db.Find(&users, accountIDCondition, accountID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -1240,10 +1240,27 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro return nil } -func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { +// GetPeerByID retrieves a peer by its ID and account ID. +func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) { + var peer *nbpeer.Peer + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&peer, accountAndIDQueryCondition, accountID, peerID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "peer not found") + } + log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peer from store") + } + + return peer, nil +} + +func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { startTime := time.Now() - result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { if errors.Is(result.Error, context.Canceled) { return status.NewStoreContextCanceledError(time.Since(startTime)) @@ -1336,42 +1353,82 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength } // GetGroupByID retrieves a group by ID and account ID. -func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { - return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID) +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) { + var group *nbgroup.Group + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "group not found") + } + log.WithContext(ctx).Errorf("failed to get group from store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get group from store") + } + + return group, nil } // GetGroupByName retrieves a group by name and account ID. -func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) { var group nbgroup.Group // TODO: This fix is accepted for now, but if we need to handle this more frequently // we may need to reconsider changing the types. - query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) + query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) if s.storeEngine == PostgresStoreEngine { query = query.Order("json_array_length(peers::json) DESC") } else { query = query.Order("json_array_length(peers) DESC") } - result := query.First(&group, "name = ? and account_id = ?", groupName, accountID) + result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "group not found") } - return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) + log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get group by name from store") } return &group, nil } // SaveGroup saves a group to the store. func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) if result.Error != nil { - return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) + log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save group to store") } return nil } +// DeleteGroup deletes a group from the database. +func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete group from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "group not found") + } + + return nil +} + +// DeleteGroups deletes groups from the database. +func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { + result := s.db.Clauses(clause.Locking{Strength: string(strength)}). + Delete(&nbgroup.Group{}, " account_id = ? AND id IN ?", accountID, groupIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error) + } + + return nil +} + // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) diff --git a/management/server/store.go b/management/server/store.go index 73c9ef6a692..cb3c533dd09 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -62,7 +62,7 @@ type Store interface { GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) + GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) SaveUsers(accountID string, users map[string]*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error @@ -75,6 +75,8 @@ type Store interface { GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error + DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error + DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) @@ -89,6 +91,7 @@ type Store interface { AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) + GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error @@ -107,7 +110,7 @@ type Store interface { GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) - IncrementNetworkSerial(ctx context.Context, accountId string) error + IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetInstallationID() string From 8126d953166ddfa79950469f42d0a8dc5084ce71 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 18:58:04 +0300 Subject: [PATCH 08/28] refactor GetGroupByID and add NewGroupNotFoundError Signed-off-by: bcmmbaga --- management/server/account.go | 4 ++-- management/server/setupkey.go | 2 +- management/server/sql_store.go | 8 ++++---- management/server/status/error.go | 5 +++++ 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 2b18c344101..2902bc9521c 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2100,7 +2100,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range addNewGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2113,7 +2113,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range removeOldGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 65d7796f1a0..a3330bba888 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -432,7 +432,7 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI autoGroups := map[string]*nbgroup.Group{} for _, groupID := range autoGroupIDs { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { return nil, err } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 506142453e6..3707aa9cec1 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1196,7 +1196,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.Errorf(status.NotFound, "group not found for account") + return status.NewGroupNotFoundError(groupID) } if errors.Is(result.Error, context.Canceled) { return status.NewStoreContextCanceledError(time.Since(startTime)) @@ -1358,7 +1358,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "group not found") + return nil, status.NewGroupNotFoundError(groupID) } log.WithContext(ctx).Errorf("failed to get group from store: %s", err) return nil, status.Errorf(status.Internal, "failed to get group from store") @@ -1383,7 +1383,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "group not found") + return nil, status.NewGroupNotFoundError(groupName) } log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error) return nil, status.Errorf(status.Internal, "failed to get group by name from store") @@ -1411,7 +1411,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "group not found") + return status.NewGroupNotFoundError(groupID) } return nil diff --git a/management/server/status/error.go b/management/server/status/error.go index 5a75c94b1c1..00be347ada4 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -135,3 +135,8 @@ func NewStoreContextCanceledError(duration time.Duration) error { func NewInvalidKeyIDError() error { return Errorf(InvalidArgument, "invalid key ID") } + +// NewGroupNotFoundError creates a new Error with NotFound type for a missing group +func NewGroupNotFoundError(groupID string) error { + return Errorf(NotFound, "group: %s not found", groupID) +} From ac05f69131651fded5f6a304b7dbe2b517a72b31 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 18:58:19 +0300 Subject: [PATCH 09/28] fix tests Signed-off-by: bcmmbaga --- management/server/account_test.go | 12 +++++++----- management/server/route_test.go | 2 +- management/server/sql_store_test.go | 8 ++++---- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index fdf004a3b8a..97e0d45f016 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1413,11 +1413,13 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - group := group.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - } + }) + + require.NoError(t, err, "failed to save group") policy := Policy{ Enabled: true, @@ -1460,7 +1462,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } - if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { + if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil { t.Errorf("delete group: %v", err) return } @@ -2714,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 0) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) @@ -2734,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) diff --git a/management/server/route_test.go b/management/server/route_test.go index 4893e19b9f3..5c848f68c7b 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1091,7 +1091,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.ListGroups(context.Background(), account.Id) + groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id) require.NoError(t, err) var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 3f3b2a453d4..20409798b0e 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1181,7 +1181,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { t.Fatal("failed to save group") return err } - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID) if err != nil { t.Fatal("failed to get group") return err @@ -1201,7 +1201,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" account, err := store.GetAccount(context.Background(), accountID) require.NoError(t, err) - users, err := store.GetAccountUsers(context.Background(), accountID) + users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err) require.Len(t, users, len(account.Users)) } @@ -1260,9 +1260,9 @@ func TestSqlite_GetGroupByName(t *testing.T) { } accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID) + group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All") require.NoError(t, err) - require.Equal(t, "All", group.Name) + require.True(t, group.IsGroupAll()) } func Test_DeleteSetupKeySuccessfully(t *testing.T) { From 7100be83cdd002c25b2bf824687f14f1b183f770 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Sat, 9 Nov 2024 01:14:30 +0300 Subject: [PATCH 10/28] Add AddPeer and RemovePeer methods to Group struct Signed-off-by: bcmmbaga --- management/server/group/group.go | 29 +++++++++ management/server/group/group_test.go | 90 +++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) create mode 100644 management/server/group/group_test.go diff --git a/management/server/group/group.go b/management/server/group/group.go index e98e5ecc4b5..bb0f5b7b6e2 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -54,3 +54,32 @@ func (g *Group) HasPeers() bool { func (g *Group) IsGroupAll() bool { return g.Name == "All" } + +// AddPeer adds peerID to Peers if not already present, +// returning true if added. +func (g *Group) AddPeer(peerID string) bool { + if peerID == "" { + return false + } + + for _, itemID := range g.Peers { + if itemID == peerID { + return false + } + } + + g.Peers = append(g.Peers, peerID) + return true +} + +// RemovePeer removes peerID from Peers if present, +// returning true if removed. +func (g *Group) RemovePeer(peerID string) bool { + for i, itemID := range g.Peers { + if itemID == peerID { + g.Peers = append(g.Peers[:i], g.Peers[i+1:]...) + return true + } + } + return false +} diff --git a/management/server/group/group_test.go b/management/server/group/group_test.go new file mode 100644 index 00000000000..cb002f8d9e1 --- /dev/null +++ b/management/server/group/group_test.go @@ -0,0 +1,90 @@ +package group + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAddPeer(t *testing.T) { + t.Run("add new peer to empty slice", func(t *testing.T) { + group := &Group{Peers: []string{}} + peerID := "peer1" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add new peer to nil slice", func(t *testing.T) { + group := &Group{Peers: nil} + peerID := "peer1" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add new peer to non-empty slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer3" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add duplicate peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer1" + assert.False(t, group.AddPeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("add empty peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "" + assert.False(t, group.AddPeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) +} + +func TestRemovePeer(t *testing.T) { + t.Run("remove existing peer from slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2", "peer3"}} + peerID := "peer2" + assert.True(t, group.RemovePeer(peerID)) + assert.NotContains(t, group.Peers, peerID) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("remove peer from empty slice", func(t *testing.T) { + group := &Group{Peers: []string{}} + peerID := "peer1" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 0, len(group.Peers)) + }) + + t.Run("remove peer from nil slice", func(t *testing.T) { + group := &Group{Peers: nil} + peerID := "peer1" + assert.False(t, group.RemovePeer(peerID)) + assert.Nil(t, group.Peers) + }) + + t.Run("remove non-existent peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer3" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("remove peer from single-item slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1"}} + peerID := "peer1" + assert.True(t, group.RemovePeer(peerID)) + assert.Equal(t, 0, len(group.Peers)) + assert.NotContains(t, group.Peers, peerID) + }) + + t.Run("remove empty peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) +} From 6dc185e141c4e10c64c8879f54a8338fd4e4c01d Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Sat, 9 Nov 2024 01:16:03 +0300 Subject: [PATCH 11/28] Preserve store engine in SqlStore transactions Signed-off-by: bcmmbaga --- management/server/sql_store.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index df0f2b3178b..8a0f432e6ae 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1116,7 +1116,8 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor func (s *SqlStore) withTx(tx *gorm.DB) Store { return &SqlStore{ - db: tx, + db: tx, + storeEngine: s.storeEngine, } } From bdeb95c58c2081b0a77776692398bc6c20be2b60 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Sat, 9 Nov 2024 01:17:01 +0300 Subject: [PATCH 12/28] Run groups ops in transaction Signed-off-by: bcmmbaga --- management/server/account.go | 4 +- management/server/group.go | 384 ++++++++++++++--------------------- management/server/peer.go | 18 +- 3 files changed, 173 insertions(+), 233 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 2902bc9521c..043b797ab41 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2126,12 +2126,12 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.GroupsPropagationEnabled { - removedGroupAffectsPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, removeOldGroups) + removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups) if err != nil { return err } - newGroupsAffectsPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, addNewGroups) + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups) if err != nil { return err } diff --git a/management/server/group.go b/management/server/group.go index da4c0fb9415..c49bb247186 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -79,7 +79,7 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI // SaveGroups adds new groups to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. -func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { +func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err @@ -89,66 +89,35 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return status.NewUserNotPartOfAccountError() } - var ( - eventsToStore []func() - groupsToSave []*nbgroup.Group - ) - - for _, newGroup := range newGroups { - if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { - return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) - } + var eventsToStore []func() + var groupsToSave []*nbgroup.Group + var updateAccountPeers bool - if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) - if err != nil { - s, ok := status.FromError(err) - if !ok || s.ErrorType != status.NotFound { - return err - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { + if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err } - // Avoid duplicate groups only for the API issued groups. - // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. - if existingGroup != nil { - return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) - } + newGroup.AccountID = accountID + groupsToSave = append(groupsToSave, newGroup) + groupIDs = append(groupIDs, newGroup.ID) - newGroup.ID = xid.New().String() + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) } - for _, peerID := range newGroup.Peers { - if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil { - return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) - } + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + if err != nil { + return err } - newGroup.AccountID = accountID - groupsToSave = append(groupsToSave, newGroup) - - events := am.prepareGroupEvents(ctx, userID, accountID, newGroup) - eventsToStore = append(eventsToStore, events...) - } - - newGroupIDs := make([]string, 0, len(newGroups)) - for _, newGroup := range newGroups { - newGroupIDs = append(newGroupIDs, newGroup.ID) - } - - updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, newGroupIDs) - if err != nil { - return err - } - - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return err } - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil { - return fmt.Errorf("failed to save groups: %w", err) - } - return nil + return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave) }) if err != nil { return err @@ -166,13 +135,13 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } // prepareGroupEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup *nbgroup.Group) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) removedPeers := make([]string, 0) - oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) + oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) if err == nil && oldGroup != nil { addedPeers = difference(newGroup.Peers, oldGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers) @@ -184,36 +153,34 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID } for _, peerID := range addedPeers { - peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) if err != nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: %v", peerID, err) continue } - peerCopy := peer // copy to avoid closure issues + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, - map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), - "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), - }) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta) }) } for _, peerID := range removedPeers { - peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) if err != nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: %v", peerID, err) continue } - peerCopy := peer // copy to avoid closure issues + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, - map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), - "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), - }) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta) }) } @@ -246,28 +213,27 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use return status.NewUserNotPartOfAccountError() } - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err - } + var group *nbgroup.Group - if group.Name == "All" { - return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + if err != nil { + return err + } - if err = am.validateDeleteGroup(ctx, group, userID); err != nil { - return err - } + if group.IsGroupAll() { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = validateDeleteGroup(ctx, transaction, group, userID); err != nil { return err } - if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID); err != nil { - return fmt.Errorf("failed to delete group: %w", err) + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err } - return nil + + return transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID) }) if err != nil { return err @@ -279,6 +245,11 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use } // DeleteGroups deletes groups from an account. +// Note: This function does not acquire the global lock. +// It is the caller's responsibility to ensure proper locking is in place before invoking this method. +// +// If an error occurs while deleting a group, the function skips it and continues deleting other groups. +// Errors are collected and returned at the end. func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { @@ -289,36 +260,31 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us return status.NewUserNotPartOfAccountError() } - var ( - allErrors error - groupIDsToDelete []string - deletedGroups []*nbgroup.Group - ) + var allErrors error + var groupIDsToDelete []string + var deletedGroups []*nbgroup.Group - for _, groupID := range groupIDs { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - continue - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for _, groupID := range groupIDs { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + if err != nil { + continue + } - if err := am.validateDeleteGroup(ctx, group, userID); err != nil { - allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) - continue - } + if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil { + allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) + continue + } - groupIDsToDelete = append(groupIDsToDelete, groupID) - deletedGroups = append(deletedGroups, group) - } + groupIDsToDelete = append(groupIDsToDelete, groupID) + deletedGroups = append(deletedGroups, group) + } - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return err } - if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil { - return fmt.Errorf("failed to delete group: %w", err) - } - return nil + return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete) }) if err != nil { return err @@ -333,36 +299,30 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err - } + var group *nbgroup.Group + var updateAccountPeers bool + var err error - add := true - for _, itemID := range group.Peers { - if itemID == peerID { - add = false - break + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + if err != nil { + return err } - } - if add { - group.Peers = append(group.Peers, peerID) - } - updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID}) - if err != nil { - return err - } + if updated := group.AddPeer(peerID); !updated { + return nil + } - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { return err } - if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil { - return fmt.Errorf("failed to save group: %w", err) + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err } - return nil + + return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) }) if err != nil { return err @@ -377,38 +337,30 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err - } + var group *nbgroup.Group + var updateAccountPeers bool + var err error - updated := false - for i, itemID := range group.Peers { - if itemID == peerID { - group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) - updated = true - break + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + if err != nil { + return err } - } - if !updated { - return nil - } + if updated := group.RemovePeer(peerID); !updated { + return nil + } - updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID}) - if err != nil { - return err - } + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return err } - if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil { - return fmt.Errorf("failed to save group: %w", err) - } - return nil + return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) }) if err != nil { return err @@ -421,10 +373,43 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return nil } -func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group *nbgroup.Group, userID string) error { +// validateNewGroup validates the new group for existence and required fields. +func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error { + if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { + return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) + } + + if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { + existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) + if err != nil { + if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { + return err + } + } + + // Prevent duplicate groups for API-issued groups. + // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. + if existingGroup != nil { + return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) + } + + newGroup.ID = xid.New().String() + } + + for _, peerID := range newGroup.Peers { + _, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) + } + } + + return nil +} + +func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user if group.Issued == nbgroup.GroupIssuedIntegration { - executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return status.Errorf(status.NotFound, "user not found") } @@ -433,27 +418,27 @@ func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group } } - if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } - if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"name server groups", linkedDns.Name} } - if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"policy", linkedPolicy.Name} } - if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"setup key", linkedSetupKey.Name} } - if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"user", linkedUser.Id} } - dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) if err != nil { return err } @@ -462,7 +447,7 @@ func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group return &GroupLinkError{"disabled DNS management groups", group.Name} } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) + settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) if err != nil { return err } @@ -477,8 +462,8 @@ func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group } // isGroupLinkedToRoute checks if a group is linked to any route in the account. -func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accountID string, groupID string) (bool, *route.Route) { - routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) { + routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) return false, nil @@ -494,8 +479,8 @@ func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accou } // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, accountID string, groupID string) (bool, *Policy) { - policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) return false, nil @@ -512,8 +497,8 @@ func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, acco } // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { - nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) return false, nil @@ -531,8 +516,8 @@ func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, account } // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, accountID string, groupID string) (bool, *SetupKey) { - setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) { + setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) return false, nil @@ -547,8 +532,8 @@ func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, ac } // isGroupLinkedToUser checks if a group is linked to any user in the account. -func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accountID string, groupID string) (bool, *User) { - users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) { + users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) return false, nil @@ -563,12 +548,12 @@ func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accoun } // areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. -func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) { +func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { if len(groupIDs) == 0 { return false, nil } - dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) if err != nil { return false, err } @@ -577,13 +562,13 @@ func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) { return true, nil } - if linked, _ := am.isGroupLinkedToDns(ctx, accountID, groupID); linked { + if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked { return true, nil } - if linked, _ := am.isGroupLinkedToPolicy(ctx, accountID, groupID); linked { + if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked { return true, nil } - if linked, _ := am.isGroupLinkedToRoute(ctx, accountID, groupID); linked { + if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked { return true, nil } } @@ -591,40 +576,6 @@ func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, return false, nil } -// isGroupLinkedToRoute checks if a group is linked to any route in the account. -func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { - for _, r := range routes { - if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { - return true, r - } - } - return false, nil -} - -// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { - for _, policy := range policies { - for _, rule := range policy.Rules { - if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { - return true, policy - } - } - } - return false, nil -} - -// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { - for _, dns := range nameServerGroups { - for _, g := range dns.Groups { - if g == groupID { - return true, dns - } - } - } - return false, nil -} - // anyGroupHasPeers checks if any of the given groups in the account have peers. func anyGroupHasPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { @@ -634,22 +585,3 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool { } return false } - -func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool { - for _, groupID := range groupIDs { - if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) { - return true - } - if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked { - return true - } - if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked { - return true - } - if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked { - return true - } - } - - return false -} diff --git a/management/server/peer.go b/management/server/peer.go index 994cc02879c..33f27d8c7e0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -331,7 +331,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - updateAccountPeers := isPeerInActiveGroup(account, peerID) + updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID) + if err != nil { + return err + } err = am.deletePeers(ctx, account, []string{peerID}, userID) if err != nil { @@ -594,9 +597,14 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if err != nil { return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err) } - groupsToAdd = append(groupsToAdd, allGroup.ID) - if areGroupChangesAffectPeers(account, groupsToAdd) { + + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd) + if err != nil { + return nil, nil, nil, err + } + + if newGroupsAffectsPeers { am.updateAccountPeers(ctx, accountID) } @@ -1033,12 +1041,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} { // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(account *Account, peerID string) bool { +func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) { peerGroupIDs := make([]string, 0) for _, group := range account.Groups { if slices.Contains(group.Peers, peerID) { peerGroupIDs = append(peerGroupIDs, group.ID) } } - return areGroupChangesAffectPeers(account, peerGroupIDs) + return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs) } From 3ed8b9cee93e7d45f3d27210606536c06169ab06 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Sat, 9 Nov 2024 01:48:28 +0300 Subject: [PATCH 13/28] fix missing group removed from setup key activity Signed-off-by: bcmmbaga --- management/server/setupkey.go | 95 +++++++++++++++++++---------------- 1 file changed, 53 insertions(+), 42 deletions(-) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 65d7796f1a0..2e8230d1ccb 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -12,8 +12,8 @@ import ( "github.com/google/uuid" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" + log "github.com/sirupsen/logrus" ) const ( @@ -236,19 +236,21 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, status.NewUserNotPartOfAccountError() } - var groups map[string]*nbgroup.Group var setupKey *SetupKey var plainKey string + var eventsToStore []func() err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - groups, err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups) - if err != nil { + if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil { return err } setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) setupKey.AccountID = accountID + events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey) + eventsToStore = append(eventsToStore, events...) + return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey) }) if err != nil { @@ -256,13 +258,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s } am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) - - for _, g := range setupKey.AutoGroups { - group, ok := groups[g] - if ok { - am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey, - map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name}) - } + for _, storeEvent := range eventsToStore { + storeEvent() } // for the creation return the plain key to the caller @@ -292,13 +289,12 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.NewUserNotPartOfAccountError() } - var groups map[string]*nbgroup.Group var oldKey *SetupKey var newKey *SetupKey + var eventsToStore []func() err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - groups, err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups) - if err != nil { + if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil { return err } @@ -314,6 +310,12 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str newKey.Revoked = keyToSave.Revoked newKey.UpdatedAt = time.Now().UTC() + addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) + removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) + + events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey) + eventsToStore = append(eventsToStore, events...) + return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey) }) if err != nil { @@ -324,26 +326,9 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str am.StoreEvent(ctx, userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta()) } - defer func() { - addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) - removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) - - for _, g := range removedGroups { - group, ok := groups[g] - if ok { - am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey, - map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) - } - } - - for _, g := range addedGroups { - group, ok := groups[g] - if ok { - am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey, - map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) - } - } - }() + for _, storeEvent := range eventsToStore { + storeEvent() + } return newKey, nil } @@ -412,7 +397,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, var deletedSetupKey *SetupKey err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) + deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) if err != nil { return err } @@ -428,20 +413,46 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return nil } -func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) (map[string]*nbgroup.Group, error) { - autoGroups := map[string]*nbgroup.Group{} - +func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error { for _, groupID := range autoGroupIDs { group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) if err != nil { - return nil, err + return err } if group.IsGroupAll() { - return nil, status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") + return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") + } + } + + return nil +} + +// prepareSetupKeyEvents prepares a list of event functions to be stored. +func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() { + var eventsToStore []func() + + for _, g := range removedGroups { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err) + continue + } + + meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name} + am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta) + } + + for _, g := range addedGroups { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err) + continue } - autoGroups[group.ID] = group + + meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name} + am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta) } - return autoGroups, nil + return eventsToStore } From 871500c5cc0523cfb5a0032a2a56bfd10366edf3 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Sat, 9 Nov 2024 01:52:09 +0300 Subject: [PATCH 14/28] fix merge Signed-off-by: bcmmbaga --- management/server/setupkey.go | 4 ++-- management/server/store.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 160e934482f..d6e92fe3ab1 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -433,7 +433,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran var eventsToStore []func() for _, g := range removedGroups { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err) continue @@ -444,7 +444,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran } for _, g := range addedGroups { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err) continue diff --git a/management/server/store.go b/management/server/store.go index cb3c533dd09..68b57204b55 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -71,7 +71,7 @@ type Store interface { DeleteTokenID2UserIDIndex(tokenID string) error GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) - GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error From 174e07fefda60632effa26df6e04e53f09eb1bbe Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 12:37:19 +0300 Subject: [PATCH 15/28] Refactor posture checks to remove get and save account Signed-off-by: bcmmbaga --- management/server/account.go | 2 +- .../server/http/posture_checks_handler.go | 3 +- management/server/mock_server/account_mock.go | 6 +- management/server/posture/checks.go | 6 - management/server/posture_checks.go | 303 +++++++++++------- management/server/posture_checks_test.go | 211 +++++++----- management/server/sql_store.go | 54 +++- management/server/status/error.go | 5 + management/server/store.go | 4 +- 9 files changed, 377 insertions(+), 217 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 043b797ab41..8ebbb0fa0a0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -139,7 +139,7 @@ type AccountManager interface { HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 1d020e9bcb7..2c820429278 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http. return } - if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { + postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks) + if err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index aa6a47b152e..673ed33bb9b 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -96,7 +96,7 @@ type MockAccountManager struct { HasConnectedChannelFunc func(peerID string) bool GetExternalCacheManagerFunc func() server.ExternalCacheManager GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager @@ -730,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p } // SavePostureChecks mocks SavePostureChecks of the AccountManager interface -func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { +func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { if am.SavePostureChecksFunc != nil { return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks) } - return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") } // DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index f2739dddf8d..b2f308d76e2 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -7,8 +7,6 @@ import ( "regexp" "github.com/hashicorp/go-version" - "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" @@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh } func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) { - if postureChecksID == "" { - postureChecksID = xid.New().String() - } - postureChecks := Checks{ ID: postureChecksID, Name: name, diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 096cff3f5c9..d7b5a79a23b 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -2,16 +2,15 @@ package server import ( "context" + "fmt" "slices" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" -) - -const ( - errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { @@ -20,85 +19,104 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID return nil, err } - if !user.HasAdminPower() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) -} + if !user.HasAdminPower() { + return nil, status.NewAdminPermissionError() + } -func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() + return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) +} - account, err := am.Store.GetAccount(ctx, accountID) +// SavePostureChecks saves a posture check. +func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return err + return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return err + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + return nil, status.NewAdminPermissionError() } - if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint - } + var updateAccountPeers bool + var isUpdate = postureChecks.ID != "" + var action = activity.PostureCheckCreated - exists, uniqName := am.savePostureChecks(account, postureChecks) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil { + return err + } - // we do not allow create new posture checks with non uniq name - if !exists && !uniqName { - return status.Errorf(status.PreconditionFailed, "Posture check name should be unique") - } + if isUpdate { + updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID) + if err != nil { + return err + } - action := activity.PostureCheckCreated - if exists { - action = activity.PostureCheckUpdated - account.Network.IncSerial() - } + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err + action = activity.PostureCheckUpdated + } + + postureChecks.AccountID = accountID + return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks) + }) + if err != nil { + return nil, err } am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) - if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - return nil + return postureChecks, nil } +// DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - user, err := account.FindUser(userID) - if err != nil { - return err + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + return status.NewAdminPermissionError() } - postureChecks, err := am.deletePostureChecks(account, postureChecksID) - if err != nil { - return err - } + var postureChecks *posture.Checks - if err = am.Store.SaveAccount(ctx, account); err != nil { + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) + if err != nil { + return err + } + + if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID) + }) + if err != nil { return err } @@ -107,132 +125,173 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun return nil } +// ListPostureChecks returns a list of posture checks. func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.HasAdminPower() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if !user.HasAdminPower() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) { - uniqName = true - for i, p := range account.PostureChecks { - if !exists && p.ID == postureChecks.ID { - account.PostureChecks[i] = postureChecks - exists = true +// getPeerPostureChecks returns the posture checks applied for a given peer. +func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) { + peerPostureChecks := make(map[string]*posture.Checks) + + err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + if len(postureChecks) == 0 { + return nil } - if p.Name == postureChecks.Name { - uniqName = false + + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil { + return err + } } + + return nil + }) + if err != nil { + return nil, err + } + + return maps.Values(peerPostureChecks), nil +} + +// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. +func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return false, err } - if !exists { - account.PostureChecks = append(account.PostureChecks, postureChecks) + + for _, policy := range policies { + if slices.Contains(policy.SourcePostureChecks, postureCheckID) { + hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups()) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + } } - return + + return false, nil } -func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) { - postureChecksIdx := -1 - for i, postureChecks := range account.PostureChecks { - if postureChecks.ID == postureChecksID { - postureChecksIdx = i - break +// validatePostureChecks validates the posture checks. +func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error { + if err := postureChecks.Validate(); err != nil { + return status.Errorf(status.InvalidArgument, err.Error()) //nolint + } + + // If the posture check already has an ID, verify its existence in the store. + if postureChecks.ID != "" { + if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil { + return err } + return nil } - if postureChecksIdx < 0 { - return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID) + + // For new posture checks, ensure no duplicates by name. + checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + if err != nil { + return err } - // Check if posture check is linked to any policy - if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked { - return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name) + for _, check := range checks { + if check.Name == postureChecks.Name && check.ID != postureChecks.ID { + return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name) + } } - postureChecks := account.PostureChecks[postureChecksIdx] - account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...) + postureChecks.ID = xid.New().String() - return postureChecks, nil + return nil } -// getPeerPostureChecks returns the posture checks applied for a given peer. -func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks { - peerPostureChecks := make(map[string]posture.Checks) +// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. +func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error { + isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy) + if err != nil { + return err + } - if len(account.PostureChecks) == 0 { + if !isInGroup { return nil } - for _, policy := range account.Policies { - if !policy.Enabled { - continue - } - - if isPeerInPolicySourceGroups(peer.ID, account, policy) { - addPolicyPostureChecks(account, policy, peerPostureChecks) + for _, sourcePostureCheckID := range policy.SourcePostureChecks { + postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID) + if err != nil { + return err } + peerPostureChecks[sourcePostureCheckID] = postureCheck } - postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks)) - for _, check := range peerPostureChecks { - checkCopy := check - postureChecksList = append(postureChecksList, &checkCopy) - } - - return postureChecksList + return nil } // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool { +func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) { for _, rule := range policy.Rules { if !rule.Enabled { continue } for _, sourceGroup := range rule.Sources { - group, ok := account.Groups[sourceGroup] - if ok && slices.Contains(group.Peers, peerID) { - return true + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup) + if err != nil { + log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err) + return false, fmt.Errorf("failed to check peer in policy source group: %w", err) } - } - } - return false -} - -func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) { - for _, sourcePostureCheckID := range policy.SourcePostureChecks { - for _, postureCheck := range account.PostureChecks { - if postureCheck.ID == sourcePostureCheckID { - peerPostureChecks[sourcePostureCheckID] = *postureCheck + if slices.Contains(group.Peers, peerID) { + return true, nil } } } -} -func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) { - for _, policy := range account.Policies { - if slices.Contains(policy.SourcePostureChecks, postureChecksID) { - return true, policy - } - } return false, nil } -// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers. -func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool { - if !exists { - return false +// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. +func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return err } - isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID) - if !isLinked { - return false + for _, policy := range policies { + if slices.Contains(policy.SourcePostureChecks, postureChecksID) { + return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name) + } } - return anyGroupHasPeers(account, linkedPolicy.ruleGroups()) + + return nil } diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index c63538b9d52..3c5c5fc79e6 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -7,6 +7,7 @@ import ( "github.com/rs/xid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/group" @@ -16,7 +17,6 @@ import ( const ( adminUserID = "adminUserID" regularUserID = "regularUserID" - postureCheckID = "existing-id" postureCheckName = "Existing check" ) @@ -33,7 +33,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { t.Run("Generic posture check flow", func(t *testing.T) { // regular users can not create checks - err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) + _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) assert.Error(t, err) // regular users cannot list check @@ -41,8 +41,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // should be possible to create posture check with uniq name - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: postureCheckID, + postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ @@ -58,8 +57,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Len(t, checks, 1) // should not be possible to create posture check with non uniq name - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: "new-id", + _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ GeoLocationCheck: &posture.GeoLocationCheck{ @@ -74,23 +72,20 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // admins can update posture checks - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: postureCheckID, - Name: postureCheckName, - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.27.0", - }, + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.27.0", }, - }) + } + _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck) assert.NoError(t, err) // users should not be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID) assert.Error(t, err) // admin should be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID) assert.NoError(t, err) checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) assert.NoError(t, err) @@ -150,9 +145,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - postureCheck := posture.Checks{ - ID: "postureCheck", - Name: "postureCheck", + postureCheckA := &posture.Checks{ + Name: "postureCheckA", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + {LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"}, + }, + }, + }, + } + postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA) + require.NoError(t, err) + + postureCheckB := &posture.Checks{ + Name: "postureCheckB", AccountID: account.Id, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ @@ -169,7 +177,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -187,12 +195,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -215,7 +223,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } // Linking posture check to policy should trigger update account peers and send peer update @@ -238,7 +246,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture checks should update account peers and send peer update t.Run("updating linked to posture check with peers", func(t *testing.T) { - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, @@ -255,7 +263,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -293,7 +301,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID) + err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID) assert.NoError(t, err) select { @@ -303,7 +311,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update @@ -321,7 +329,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) assert.NoError(t, err) @@ -332,12 +340,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -367,7 +375,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) @@ -379,12 +387,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -409,7 +417,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) assert.NoError(t, err) @@ -420,7 +428,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ ProcessCheck: &posture.ProcessCheck{ Processes: []posture.Process{ { @@ -429,7 +437,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -440,80 +448,123 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }) } -func TestArePostureCheckChangesAffectingPeers(t *testing.T) { - account := &Account{ - Policies: []*Policy{ - { - ID: "policyA", - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - }, - }, - SourcePostureChecks: []string{"checkA"}, - }, +func TestArePostureCheckChangesAffectPeers(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "failed to create account manager") + + account, err := initTestPostureChecksAccount(manager) + require.NoError(t, err, "failed to init testing account") + + groupA := &group.Group{ + ID: "groupA", + AccountID: account.Id, + Peers: []string{"peer1"}, + } + + groupB := &group.Group{ + ID: "groupB", + AccountID: account.Id, + Peers: []string{}, + } + err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB}) + require.NoError(t, err, "failed to save groups") + + postureCheckA := &posture.Checks{ + Name: "checkA", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, }, - Groups: map[string]*group.Group{ - "groupA": { - ID: "groupA", - Peers: []string{"peer1"}, - }, - "groupB": { - ID: "groupB", - Peers: []string{}, - }, + } + postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA) + require.NoError(t, err, "failed to save postureCheckA") + + postureCheckB := &posture.Checks{ + Name: "checkB", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, }, - PostureChecks: []*posture.Checks{ - { - ID: "checkA", - }, + } + postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB) + require.NoError(t, err, "failed to save postureCheckB") + + policy := &Policy{ + ID: "policyA", + AccountID: account.Id, + Rules: []*PolicyRule{ { - ID: "checkB", + ID: "ruleA", + PolicyID: "policyA", + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, }, }, + SourcePostureChecks: []string{postureCheckA.ID}, } + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, false) + require.NoError(t, err, "failed to save policy") + t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "checkB", true) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID) + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check does not exist", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "unknown", false) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown") + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"groupB"} - account.Policies[0].Rules[0].Destinations = []string{"groupA"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"groupB"} + policy.Rules[0].Destinations = []string{"groupA"} + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"groupA"} - account.Policies[0].Rules[0].Destinations = []string{"groupB"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"groupA"} + policy.Rules[0].Destinations = []string{"groupB"} + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) - t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"} - account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { + groupA.Peers = []string{} + err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA) + require.NoError(t, err, "failed to save groups") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.False(t, result) }) - t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { - account.Groups["groupA"].Peers = []string{} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { + policy.Rules[0].Sources = []string{"nonExistentGroup"} + policy.Rules[0].Destinations = []string{"nonExistentGroup"} + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.False(t, result) }) } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 8a0f432e6ae..466d36aff92 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1257,12 +1257,60 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { - return getRecords[*posture.Checks](s.db, lockStrength, accountID) + var postureChecks []*posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture checks from store") + } + + return postureChecks, nil } // GetPostureChecksByID retrieves posture checks by their ID and account ID. -func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { - return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID) +func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) { + var postureCheck *posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewPostureChecksNotFoundError(postureChecksID) + } + log.WithContext(ctx).Errorf("failed to get posture check from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture check from store") + } + + return postureCheck, nil +} + +// SavePostureChecks saves a posture checks to the database. +func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrDuplicatedKey) { + return status.Errorf(status.InvalidArgument, "name should be unique") + } + log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error) + return status.Errorf(status.Internal, "failed to save posture checks to store") + } + + return nil +} + +// DeletePostureChecks deletes a posture checks from the database. +func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete posture checks from store") + } + + if result.RowsAffected == 0 { + return status.NewPostureChecksNotFoundError(postureChecksID) + } + + return nil } // GetAccountRoutes retrieves network routes for an account. diff --git a/management/server/status/error.go b/management/server/status/error.go index 00be347ada4..bdf5c754946 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -140,3 +140,8 @@ func NewInvalidKeyIDError() error { func NewGroupNotFoundError(groupID string) error { return Errorf(NotFound, "group: %s not found", groupID) } + +// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks +func NewPostureChecksNotFoundError(postureChecksID string) error { + return Errorf(NotFound, "posture checks: %s not found", postureChecksID) +} diff --git a/management/server/store.go b/management/server/store.go index 68b57204b55..7e258104558 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -83,7 +83,9 @@ type Store interface { GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) - GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) + GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) + SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error + DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error From d54b6967ce28b07ff799c08a8d4d789b0dfde322 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 12:38:34 +0300 Subject: [PATCH 16/28] fix refactor Signed-off-by: bcmmbaga --- management/server/dns.go | 2 +- management/server/group.go | 19 +++++++++++++++++-- management/server/nameserver.go | 10 +++++----- management/server/peer.go | 25 +++++++++++++++++++++---- management/server/policy.go | 6 +++--- management/server/route.go | 10 +++++----- 6 files changed, 52 insertions(+), 20 deletions(-) diff --git a/management/server/dns.go b/management/server/dns.go index 4551be5ab92..e52be601639 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -145,7 +145,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) } - if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { + if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) { am.updateAccountPeers(ctx, accountID) } diff --git a/management/server/group.go b/management/server/group.go index c49bb247186..ee42b0064a7 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -576,8 +576,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI return false, nil } -// anyGroupHasPeers checks if any of the given groups in the account have peers. -func anyGroupHasPeers(account *Account, groupIDs []string) bool { +func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { if group, exists := account.Groups[groupID]; exists && group.HasPeers() { return true @@ -585,3 +584,19 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool { } return false } + +// anyGroupHasPeers checks if any of the given groups in the account have peers. +func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { + for _, groupID := range groupIDs { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + if err != nil { + return false, err + } + + if group.HasPeers() { + return true, nil + } + } + + return false, nil +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 957008714e5..9119a3dec72 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -70,7 +70,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco return nil, err } - if anyGroupHasPeers(account, newNSGroup.Groups) { + if am.anyGroupHasPeers(account, newNSGroup.Groups) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) @@ -105,7 +105,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } - if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { + if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) @@ -135,7 +135,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return err } - if anyGroupHasPeers(account, nsGroup.Groups) { + if am.anyGroupHasPeers(account, nsGroup.Groups) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) @@ -279,9 +279,9 @@ func validateDomain(domain string) error { } // areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. -func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { +func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { if !newNSGroup.Enabled && !oldNSGroup.Enabled { return false } - return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups) + return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups) } diff --git a/management/server/peer.go b/management/server/peer.go index 33f27d8c7e0..873b460ebae 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -613,7 +613,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, err } - postureChecks := am.getPeerPostureChecks(account, newPeer) + postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID) + if err != nil { + return nil, nil, nil, err + } + customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) return newPeer, networkMap, postureChecks, nil @@ -695,7 +699,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac if err != nil { return nil, nil, nil, err } - postureChecks = am.getPeerPostureChecks(account, peer) + + postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + if err != nil { + return nil, nil, nil, err + } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil @@ -868,7 +876,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is if err != nil { return nil, nil, nil, err } - postureChecks = am.getPeerPostureChecks(account, peer) + + postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + if err != nil { + return nil, nil, nil, err + } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil @@ -1021,7 +1033,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account defer wg.Done() defer func() { <-semaphore }() - postureChecks := am.getPeerPostureChecks(account, p) + postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err) + return + } + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) diff --git a/management/server/policy.go b/management/server/policy.go index 8a5733f011c..c7872591d5e 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -405,7 +405,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - if anyGroupHasPeers(account, policy.ruleGroups()) { + if am.anyGroupHasPeers(account, policy.ruleGroups()) { am.updateAccountPeers(ctx, accountID) } @@ -469,7 +469,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli if !policyToSave.Enabled && !oldPolicy.Enabled { return false, nil } - updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups()) + updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups()) return updateAccountPeers, nil } @@ -477,7 +477,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli // Add the new policy to the account account.Policies = append(account.Policies, policyToSave) - return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil + return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil } func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { diff --git a/management/server/route.go b/management/server/route.go index dcf2cb0d32c..ecb562645e6 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -237,7 +237,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, err } - if isRouteChangeAffectPeers(account, &newRoute) { + if am.isRouteChangeAffectPeers(account, &newRoute) { am.updateAccountPeers(ctx, accountID) } @@ -323,7 +323,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } - if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { + if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) { am.updateAccountPeers(ctx, accountID) } @@ -355,7 +355,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - if isRouteChangeAffectPeers(account, routy) { + if am.isRouteChangeAffectPeers(account, routy) { am.updateAccountPeers(ctx, accountID) } @@ -651,6 +651,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { // isRouteChangeAffectPeers checks if a given route affects peers by determining // if it has a routing peer, distribution, or peer groups that include peers -func isRouteChangeAffectPeers(account *Account, route *route.Route) bool { - return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" +func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool { + return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" } From 601d429d8299302026e775a303792f38364eecaa Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 16:26:12 +0300 Subject: [PATCH 17/28] fix tests Signed-off-by: bcmmbaga --- management/server/http/posture_checks_handler_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go index 02f0f0d8308..f400cec8154 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH } return p, nil }, - SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error { + SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { postureChecks.ID = "postureCheck" testPostureChecks[postureChecks.ID] = postureChecks if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint + return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint } - return nil + return postureChecks, nil }, DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error { _, ok := testPostureChecks[postureChecksID] From 664d1388aab5f283684b09ad0e47560dfab48df7 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 22:29:59 +0300 Subject: [PATCH 18/28] fix merge Signed-off-by: bcmmbaga --- management/server/sql_store.go | 22 ++++++++++++---------- management/server/status/error.go | 1 - 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 730fb990059..502a83f2e32 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -33,12 +33,13 @@ import ( ) const ( - storeSqliteFileName = "store.db" - idQueryCondition = "id = ?" - keyQueryCondition = "key = ?" - accountAndIDQueryCondition = "account_id = ? and id = ?" - accountIDCondition = "account_id = ?" - peerNotFoundFMT = "peer %s not found" + storeSqliteFileName = "store.db" + idQueryCondition = "id = ?" + keyQueryCondition = "key = ?" + accountAndIDQueryCondition = "account_id = ? and id = ?" + accountAndIDsQueryCondition = "account_id = ? AND id IN ?" + accountIDCondition = "account_id = ?" + peerNotFoundFMT = "peer %s not found" ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -1095,10 +1096,11 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength } func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { - return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) + log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) + return status.Errorf(status.Internal, "failed to increment network serial count in store") } return nil } @@ -1213,7 +1215,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren // GetGroupsByIDs retrieves groups by their IDs and account ID. func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) { var groups []*nbgroup.Group - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") @@ -1256,7 +1258,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength // DeleteGroups deletes groups from the database. func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { result := s.db.Clauses(clause.Locking{Strength: string(strength)}). - Delete(&nbgroup.Group{}, " account_id = ? AND id IN ?", accountID, groupIDs) + Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error) diff --git a/management/server/status/error.go b/management/server/status/error.go index 6957a7e0558..db6e4c2fb5a 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -3,7 +3,6 @@ package status import ( "errors" "fmt" - "time" ) const ( From ab00c41dada6f97d13a3f53f3937071b21621c90 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 22:38:24 +0300 Subject: [PATCH 19/28] fix sonar Signed-off-by: bcmmbaga --- management/server/group.go | 23 +++++++++++++++++++---- management/server/group/group.go | 6 ++---- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index c49bb247186..1afb8f3c5e9 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -89,6 +89,10 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var eventsToStore []func() var groupsToSave []*nbgroup.Group var updateAccountPeers bool @@ -213,6 +217,10 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use return status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var group *nbgroup.Group err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { @@ -260,6 +268,10 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us return status.NewUserNotPartOfAccountError() } + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var allErrors error var groupIDsToDelete []string var deletedGroups []*nbgroup.Group @@ -438,6 +450,11 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. return &GroupLinkError{"user", linkedUser.Id} } + return checkGroupLinkedToSettings(ctx, transaction, group) +} + +// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. +func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error { dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) if err != nil { return err @@ -452,10 +469,8 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. return err } - if settings.Extra != nil { - if slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { - return &GroupLinkError{"integrated validator", group.Name} - } + if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { + return &GroupLinkError{"integrated validator", group.Name} } return nil diff --git a/management/server/group/group.go b/management/server/group/group.go index bb0f5b7b6e2..24c60d3ceef 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -55,8 +55,7 @@ func (g *Group) IsGroupAll() bool { return g.Name == "All" } -// AddPeer adds peerID to Peers if not already present, -// returning true if added. +// AddPeer adds peerID to Peers if not present, returning true if added. func (g *Group) AddPeer(peerID string) bool { if peerID == "" { return false @@ -72,8 +71,7 @@ func (g *Group) AddPeer(peerID string) bool { return true } -// RemovePeer removes peerID from Peers if present, -// returning true if removed. +// RemovePeer removes peerID from Peers if present, returning true if removed. func (g *Group) RemovePeer(peerID string) bool { for i, itemID := range g.Peers { if itemID == peerID { From 113c21b0e1b56f2c2cc9dd9dc2fd8b3a74365632 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 22:57:24 +0300 Subject: [PATCH 20/28] Change setup key log level to debug for missing group Signed-off-by: bcmmbaga --- management/server/setupkey.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 554c66ba4fc..f055d877fe2 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -449,14 +449,14 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran modifiedGroups := slices.Concat(addedGroups, removedGroups) groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) if err != nil { - log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err) + log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) return nil } for _, g := range removedGroups { group, ok := groups[g] if !ok { - log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err) + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g) continue } @@ -469,7 +469,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran for _, g := range addedGroups { group, ok := groups[g] if !ok { - log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err) + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g) continue } From d23b5c892b923ad5b3a8f45368da13de26ef64ea Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 11 Nov 2024 22:58:22 +0300 Subject: [PATCH 21/28] Retrieve modified peers once for group events Signed-off-by: bcmmbaga --- management/server/group.go | 35 ++++++++++++++++++++-------------- management/server/sql_store.go | 17 +++++++++++++++++ management/server/store.go | 1 + 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index 1afb8f3c5e9..57960e7f94a 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -156,34 +156,41 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac }) } + modifiedPeers := slices.Concat(addedPeers, removedPeers) + peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers) + if err != nil { + log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) + return nil + } + for _, peerID := range addedPeers { - peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) - if err != nil { - log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: %v", peerID, err) + peer, ok := peers[peerID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID) continue } - meta := map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), - } eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta) }) } for _, peerID := range removedPeers { - peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) - if err != nil { - log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: %v", peerID, err) + peer, ok := peers[peerID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID) continue } - meta := map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), - } eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta) }) } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 502a83f2e32..7c741d35c8e 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1095,6 +1095,23 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength return peer, nil } +// GetPeersByIDs retrieves peers by their IDs and account ID. +func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get peers by ID's from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers by ID's from the store") + } + + peersMap := make(map[string]*nbpeer.Peer) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + + return peersMap, nil +} + func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) diff --git a/management/server/store.go b/management/server/store.go index 2a0c44c678d..71b0d457b4c 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -93,6 +93,7 @@ type Store interface { GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) + GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error From 2806d7316100fb59f6498ebd6aae6975faf62477 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 12 Nov 2024 13:38:34 +0300 Subject: [PATCH 22/28] Add tests Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 277 +++++++++++++++++++++++++++- 1 file changed, 274 insertions(+), 3 deletions(-) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 20409798b0e..114da1ee6f6 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -14,11 +14,10 @@ import ( "time" "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" route2 "github.com/netbirdio/netbird/route" @@ -1293,3 +1292,275 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID) require.Error(t, err) } + +func TestSqlStore_GetGroupsByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupIDs []string + expectedCount int + }{ + { + name: "retrieve existing groups by existing IDs", + groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"}, + expectedCount: 2, + }, + { + name: "empty group IDs list", + groupIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing group IDs", + groupIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing group IDs", + groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs) + require.NoError(t, err) + require.Len(t, groups, tt.expectedCount) + }) + } +} + +func TestSqlStore_SaveGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + group := &nbgroup.Group{ + ID: "group-id", + AccountID: accountID, + Issued: "api", + Peers: []string{"peer1", "peer2"}, + } + err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + require.NoError(t, err) + + savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id") + require.NoError(t, err) + require.Equal(t, savedGroup, group) +} + +func TestSqlStore_SaveGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + groups := []*nbgroup.Group{ + { + ID: "group-1", + AccountID: accountID, + Issued: "api", + Peers: []string{"peer1", "peer2"}, + }, + { + ID: "group-2", + AccountID: accountID, + Issued: "integration", + Peers: []string{"peer3", "peer4"}, + }, + } + err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups) + require.NoError(t, err) +} + +func TestSqlStore_DeleteGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupID string + expectError bool + }{ + { + name: "delete existing group", + groupID: "cfefqs706sqkneg59g4g", + expectError: false, + }, + { + name: "delete non-existing group", + groupID: "non-existing-group-id", + expectError: true, + }, + { + name: "delete with empty group ID", + groupID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + } else { + require.NoError(t, err) + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID) + require.Error(t, err) + require.Nil(t, group) + } + }) + } +} + +func TestSqlStore_DeleteGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupIDs []string + expectError bool + }{ + { + name: "delete multiple existing groups", + groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"}, + expectError: false, + }, + { + name: "delete non-existing groups", + groupIDs: []string{"non-existing-id-1", "non-existing-id-2"}, + expectError: false, + }, + { + name: "delete with empty group IDs list", + groupIDs: []string{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + + for _, groupID := range tt.groupIDs { + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.Error(t, err) + require.Nil(t, group) + } + } + }) + } +} + +func TestSqlStore_GetPeerByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + peerID string + expectError bool + }{ + { + name: "retrieve existing peer", + peerID: "cfefqs706sqkneg59g4g", + expectError: false, + }, + { + name: "retrieve non-existing peer", + peerID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty peer ID", + peerID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, peer) + } else { + require.NoError(t, err) + require.NotNil(t, peer) + require.Equal(t, tt.peerID, peer.ID) + } + }) + } +} + +func TestSqlStore_GetPeersByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + peerIDs []string + expectedCount int + }{ + { + name: "retrieve existing peers by existing IDs", + peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"}, + expectedCount: 2, + }, + { + name: "empty peer IDs list", + peerIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing peer IDs", + peerIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing peer IDs", + peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} From a3abc211b3ceee7da582721a6917bf46d23ce19e Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 12 Nov 2024 17:11:56 +0300 Subject: [PATCH 23/28] Add tests Signed-off-by: bcmmbaga --- management/server/sql_store.go | 5 +- management/server/sql_store_test.go | 135 ++++++++++++++++++ management/server/testdata/extended-store.sql | 1 + 3 files changed, 137 insertions(+), 4 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 81dc704c213..f971f830088 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1324,11 +1324,8 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin // SavePostureChecks saves a posture checks to the database. func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) if result.Error != nil { - if errors.Is(result.Error, gorm.ErrDuplicatedKey) { - return status.Errorf(status.InvalidArgument, "name should be unique") - } log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error) return status.Errorf(status.Internal, "failed to save posture checks to store") } diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 114da1ee6f6..94c4da6a82c 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -16,6 +16,7 @@ import ( "github.com/google/uuid" nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/posture" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1564,3 +1565,137 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) { }) } } + +func TestSqlStore_GetPostureChecksByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + postureChecksID string + expectError bool + }{ + { + name: "retrieve existing posture checks", + postureChecksID: "csplshq7qv948l48f7t0", + expectError: false, + }, + { + name: "retrieve non-existing posture checks", + postureChecksID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty posture checks ID", + postureChecksID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peer, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, peer) + } else { + require.NoError(t, err) + require.NotNil(t, peer) + require.Equal(t, tt.postureChecksID, peer.ID) + } + }) + } +} + +func TestSqlStore_SavePostureChecks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + postureChecks := &posture.Checks{ + ID: "posture-checks-id", + AccountID: accountID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.31.0", + }, + OSVersionCheck: &posture.OSVersionCheck{ + Ios: &posture.MinVersionCheck{ + MinVersion: "13.0.1", + }, + Linux: &posture.MinKernelVersionCheck{ + MinKernelVersion: "5.3.3-dev", + }, + }, + GeoLocationCheck: &posture.GeoLocationCheck{ + Locations: []posture.Location{ + { + CountryCode: "DE", + CityName: "Berlin", + }, + }, + Action: posture.CheckActionAllow, + }, + }, + } + err = store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureChecks) + require.NoError(t, err) + + savePostureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, "posture-checks-id") + require.NoError(t, err) + require.Equal(t, savePostureChecks, postureChecks) +} + +func TestSqlStore_DeletePostureChecks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + postureChecksID string + expectError bool + }{ + { + name: "delete existing posture checks", + postureChecksID: "csplshq7qv948l48f7t0", + expectError: false, + }, + { + name: "delete non-existing posture checks", + postureChecksID: "non-existing-posture-checks-id", + expectError: true, + }, + { + name: "delete with empty posture checks ID", + postureChecksID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err = store.DeletePostureChecks(context.Background(), LockingStrengthUpdate, accountID, tt.postureChecksID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + } else { + require.NoError(t, err) + group, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + require.Error(t, err) + require.Nil(t, group) + } + }) + } +} diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index b522741e7e0..1646ff4da6c 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -34,4 +34,5 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003' INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); +INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO installations VALUES(1,''); From bbaee18cd56cbb5c35cbeda907c9ed4a05d6b482 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 12 Nov 2024 19:05:57 +0300 Subject: [PATCH 24/28] Fix typo Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 94c4da6a82c..de939e8d0e9 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1596,17 +1596,17 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - peer, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + postureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) require.True(t, ok) require.Equal(t, sErr.Type(), status.NotFound) - require.Nil(t, peer) + require.Nil(t, postureChecks) } else { require.NoError(t, err) - require.NotNil(t, peer) - require.Equal(t, tt.postureChecksID, peer.ID) + require.NotNil(t, postureChecks) + require.Equal(t, tt.postureChecksID, postureChecks.ID) } }) } From 9872bee41db34da97e075c2f4491b4db9d57d9b8 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 12 Nov 2024 23:53:29 +0300 Subject: [PATCH 25/28] Refactor anyGroupHasPeers to retrieve all groups once Signed-off-by: bcmmbaga --- management/server/group.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index 5d301416902..758b28b760d 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -609,12 +609,12 @@ func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []s // anyGroupHasPeers checks if any of the given groups in the account have peers. func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { - for _, groupID := range groupIDs { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return false, err - } + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs) + if err != nil { + return false, err + } + for _, group := range groups { if group.HasPeers() { return true, nil } From ed047ec9dda048120edf4f074162a27136ac3cd6 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 13 Nov 2024 16:16:30 +0300 Subject: [PATCH 26/28] Add account locking and merge group deletion methods Signed-off-by: bcmmbaga --- management/server/group.go | 66 ++++++++++------------------------ management/server/sql_store.go | 2 +- 2 files changed, 20 insertions(+), 48 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index 57960e7f94a..154a33b1350 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -215,48 +215,9 @@ func difference(a, b []string) []string { // DeleteGroup object of the peers. func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) - if err != nil { - return err - } - - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() - } - - if user.IsRegularUser() { - return status.NewAdminPermissionError() - } - - var group *nbgroup.Group - - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err - } - - if group.IsGroupAll() { - return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") - } - - if err = validateDeleteGroup(ctx, transaction, group, userID); err != nil { - return err - } - - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { - return err - } - - return transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID) - }) - if err != nil { - return err - } - - am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta()) - - return nil + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + return am.DeleteGroups(ctx, accountID, userID, []string{groupID}) } // DeleteGroups deletes groups from an account. @@ -285,13 +246,14 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { for _, groupID := range groupIDs { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID) if err != nil { + allErrors = errors.Join(allErrors, err) continue } if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil { - allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) + allErrors = errors.Join(allErrors, err) continue } @@ -318,12 +280,15 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var group *nbgroup.Group var updateAccountPeers bool var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -356,12 +321,15 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var group *nbgroup.Group var updateAccountPeers bool var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -430,13 +398,17 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. if group.Issued == nbgroup.GroupIssuedIntegration { executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return status.Errorf(status.NotFound, "user not found") + return err } if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") } } + if group.IsGroupAll() { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 7c741d35c8e..0ebda6440c1 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1278,7 +1278,7 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) - return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete groups from store") } return nil From a4d905ffe77881b682a4798d5564b89860404a0a Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 13 Nov 2024 16:56:22 +0300 Subject: [PATCH 27/28] Fix tests Signed-off-by: bcmmbaga --- management/server/group_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/group_test.go b/management/server/group_test.go index 89184e81927..59094a23e92 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { { name: "delete non-existent group", groupIDs: []string{"non-existent-group"}, - expectedDeleted: []string{"non-existent-group"}, + expectedReasons: []string{"group: non-existent-group not found"}, }, { name: "delete multiple groups with mixed results", From 51c1ec283cb9d9dacc9ec18ab6d98b64d954d362 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 15 Nov 2024 19:34:57 +0300 Subject: [PATCH 28/28] Add locks and remove log Signed-off-by: bcmmbaga --- management/server/posture_checks.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index d7b5a79a23b..59e726c4165 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -9,7 +9,6 @@ import ( "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/rs/xid" - log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" ) @@ -32,6 +31,9 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID // SavePostureChecks saves a posture check. func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err @@ -85,6 +87,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI // DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err @@ -267,7 +272,6 @@ func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountI for _, sourceGroup := range rule.Sources { group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup) if err != nil { - log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err) return false, fmt.Errorf("failed to check peer in policy source group: %w", err) }