diff --git a/management/server/file_store.go b/management/server/file_store.go index 316feb867be..a18e0e53923 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -983,6 +983,9 @@ func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Accou return nil } +func (s *FileStore) GetGroupByID(_ context.Context, _, _ string) (*nbgroup.Group, error) { + return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented") +} func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented") } diff --git a/management/server/group.go b/management/server/group.go index 9343f2dd2f8..60d895d0ab4 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -25,36 +25,38 @@ func (e *GroupLinkError) Error() string { return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name) } -// GetGroup object of the peers -func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { - groups, err := am.GetAllGroups(ctx, accountID, userID) +// CheckGroupPermissions validates if a user has the necessary permissions to view groups +func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { - return nil, err + return err } - for _, group := range groups { - if group.ID == groupID { - return group, nil - } + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err + } + + if !user.HasAdminPower() && !user.IsServiceUser && settings.RegularUsersViewBlocked { + return status.Errorf(status.PermissionDenied, "groups are blocked for users") } - return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID) + return nil } -// GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) { - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { +// GetGroup returns a specific group by groupID in an account +func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { + if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) - if err != nil { - return nil, err - } + return am.Store.GetGroupByID(ctx, groupID, accountID) +} - if !user.HasAdminPower() && !user.IsServiceUser && settings.RegularUsersViewBlocked { - return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") +// GetAllGroups returns all groups in an account +func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { + if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { + return nil, err } return am.Store.GetAccountGroups(ctx, accountID) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index b76846c9fb9..d843e6f1d29 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1087,12 +1087,24 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", "", status.Errorf(status.NotFound, "account not found") } - return "", "", status.Errorf(status.Internal, "failed to retrieve account fields") + return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error) } return account.Domain, account.DomainCategory, nil } +func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error) { + var group nbgroup.Group + result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Where(accountAndIDQueryCondition, accountID, groupID).First(&group) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "group not found") + } + return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) + } + return &group, nil +} + // GetGroupByName retrieves a group by name and account ID. func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { var group nbgroup.Group @@ -1102,7 +1114,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "group not found") } - return nil, status.Errorf(status.Internal, "failed to retrieve group fields") + return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) } return &group, nil } diff --git a/management/server/store.go b/management/server/store.go index 10a52db98fa..73e68531c57 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -64,6 +64,7 @@ type Store interface { DeleteTokenID2UserIDIndex(tokenID string) error GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) + GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error