Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[management] Refactor posture check to use store methods #2874

Merged
merged 38 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
389c961
Refactor setup key handling to use store methods
bcmmbaga Nov 7, 2024
78044c2
add lock to get account groups
bcmmbaga Nov 7, 2024
1a5f3c6
add check for regular user
bcmmbaga Nov 7, 2024
931521d
get only required groups for auto-group validation
bcmmbaga Nov 7, 2024
f8b5eed
add account lock and return auto groups map on validation
bcmmbaga Nov 8, 2024
106fc75
refactor account peers update
bcmmbaga Nov 8, 2024
0a70e4c
Refactor groups to use store methods
bcmmbaga Nov 8, 2024
8126d95
refactor GetGroupByID and add NewGroupNotFoundError
bcmmbaga Nov 8, 2024
ac05f69
fix tests
bcmmbaga Nov 8, 2024
40af1a5
Merge branch 'feature/get-account-refactoring' into setupkey-get-acco…
bcmmbaga Nov 8, 2024
d58cf50
Merge branch 'setupkey-get-account-refactoring' into groups-get-accou…
bcmmbaga Nov 8, 2024
7100be8
Add AddPeer and RemovePeer methods to Group struct
bcmmbaga Nov 8, 2024
6dc185e
Preserve store engine in SqlStore transactions
bcmmbaga Nov 8, 2024
bdeb95c
Run groups ops in transaction
bcmmbaga Nov 8, 2024
3ed8b9c
fix missing group removed from setup key activity
bcmmbaga Nov 8, 2024
cc04aef
Merge branch 'setupkey-get-account-refactoring' into groups-get-accou…
bcmmbaga Nov 8, 2024
871500c
fix merge
bcmmbaga Nov 8, 2024
174e07f
Refactor posture checks to remove get and save account
bcmmbaga Nov 11, 2024
d54b696
fix refactor
bcmmbaga Nov 11, 2024
601d429
fix tests
bcmmbaga Nov 11, 2024
010a8bf
Merge branch 'main' into groups-get-account-refactoring
bcmmbaga Nov 11, 2024
664d138
fix merge
bcmmbaga Nov 11, 2024
ab00c41
fix sonar
bcmmbaga Nov 11, 2024
113c21b
Change setup key log level to debug for missing group
bcmmbaga Nov 11, 2024
d23b5c8
Retrieve modified peers once for group events
bcmmbaga Nov 11, 2024
2806d73
Add tests
bcmmbaga Nov 12, 2024
00023bf
Merge branch 'groups-get-account-refactoring' into posturechecks-get-…
bcmmbaga Nov 12, 2024
a3abc21
Add tests
bcmmbaga Nov 12, 2024
ed259a6
Merge branch 'main' into groups-get-account-refactoring
bcmmbaga Nov 12, 2024
446de5e
Merge branch 'groups-get-account-refactoring' into posturechecks-get-…
bcmmbaga Nov 12, 2024
bbaee18
Fix typo
bcmmbaga Nov 12, 2024
9872bee
Refactor anyGroupHasPeers to retrieve all groups once
bcmmbaga Nov 12, 2024
ed047ec
Add account locking and merge group deletion methods
bcmmbaga Nov 13, 2024
a4d905f
Fix tests
bcmmbaga Nov 13, 2024
92b9e11
Merge branch 'main' into groups-get-account-refactoring
bcmmbaga Nov 15, 2024
51c1ec2
Add locks and remove log
bcmmbaga Nov 15, 2024
1ff8f61
Merge branch 'main' into groups-get-account-refactoring
bcmmbaga Nov 15, 2024
d4c7124
Merge branch 'groups-get-account-refactoring' into posturechecks-get-…
bcmmbaga Nov 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ type AccountManager interface {
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManager() idp.Manager
Expand Down
2 changes: 1 addition & 1 deletion management/server/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}

if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) {
am.updateAccountPeers(ctx, accountID)
}

Expand Down
19 changes: 17 additions & 2 deletions management/server/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -566,12 +566,27 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI
return false, nil
}

// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true
}
}
return false
}

// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
if err != nil {
return false, err
}

for _, group := range groups {
if group.HasPeers() {
return true, nil
}
}

return false, nil
}
3 changes: 2 additions & 1 deletion management/server/http/posture_checks_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
return
}

if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
Expand Down
6 changes: 3 additions & 3 deletions management/server/http/posture_checks_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
}
return p, nil
},
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks

if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint
}

return nil
return postureChecks, nil
},
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
_, ok := testPostureChecks[postureChecksID]
Expand Down
6 changes: 3 additions & 3 deletions management/server/mock_server/account_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ type MockAccountManager struct {
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
Expand Down Expand Up @@ -730,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
}

// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
if am.SavePostureChecksFunc != nil {
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
}
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
}

// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface
Expand Down
10 changes: 5 additions & 5 deletions management/server/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
return nil, err
}

if anyGroupHasPeers(account, newNSGroup.Groups) {
if am.anyGroupHasPeers(account, newNSGroup.Groups) {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
Expand Down Expand Up @@ -105,7 +105,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err
}

if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
Expand Down Expand Up @@ -135,7 +135,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err
}

if anyGroupHasPeers(account, nsGroup.Groups) {
if am.anyGroupHasPeers(account, nsGroup.Groups) {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
Expand Down Expand Up @@ -279,9 +279,9 @@ func validateDomain(domain string) error {
}

// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false
}
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups)
}
25 changes: 21 additions & 4 deletions management/server/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err
}

postureChecks := am.getPeerPostureChecks(account, newPeer)
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID)
if err != nil {
return nil, nil, nil, err
}

customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
Expand Down Expand Up @@ -700,7 +704,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
}
postureChecks = am.getPeerPostureChecks(account, peer)

postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
if err != nil {
return nil, nil, nil, err
}

customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
Expand Down Expand Up @@ -873,7 +881,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
if err != nil {
return nil, nil, nil, err
}
postureChecks = am.getPeerPostureChecks(account, peer)

postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
if err != nil {
return nil, nil, nil, err
}

customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
Expand Down Expand Up @@ -1026,7 +1038,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
defer wg.Done()
defer func() { <-semaphore }()

postureChecks := am.getPeerPostureChecks(account, p)
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err)
return
}

remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
Expand Down
6 changes: 3 additions & 3 deletions management/server/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po

am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())

if anyGroupHasPeers(account, policy.ruleGroups()) {
if am.anyGroupHasPeers(account, policy.ruleGroups()) {
am.updateAccountPeers(ctx, accountID)
}

Expand Down Expand Up @@ -469,15 +469,15 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
if !policyToSave.Enabled && !oldPolicy.Enabled {
return false, nil
}
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups())

return updateAccountPeers, nil
}

// Add the new policy to the account
account.Policies = append(account.Policies, policyToSave)

return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
}

func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
Expand Down
6 changes: 0 additions & 6 deletions management/server/posture/checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"regexp"

"github.com/hashicorp/go-version"
"github.com/rs/xid"

"github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
Expand Down Expand Up @@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh
}

func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
if postureChecksID == "" {
postureChecksID = xid.New().String()
}

postureChecks := Checks{
ID: postureChecksID,
Name: name,
Expand Down
Loading
Loading