Skip to content

Commit

Permalink
[client/management] add peer lock to peer meta update and fix isEqual…
Browse files Browse the repository at this point in the history
… func (#2840)
  • Loading branch information
pascal-fischer authored Nov 15, 2024
1 parent 44e799c commit 4aee3c9
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 1 deletion.
12 changes: 12 additions & 0 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"reflect"
"runtime"
"slices"
"sort"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -1484,6 +1485,17 @@ func (e *Engine) stopDNSServer() {

// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks {
sort.Slice(check.Files, func(i, j int) bool {
return check.Files[i] < check.Files[j]
})
}
for _, oCheck := range oChecks {
sort.Slice(oCheck.Files, func(i, j int) bool {
return oCheck.Files[i] < oCheck.Files[j]
})
}

return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
Expand Down
93 changes: 93 additions & 0 deletions client/internal/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
}
}

func Test_CheckFilesEqual(t *testing.T) {
testCases := []struct {
name string
inputChecks1 []*mgmtProto.Checks
inputChecks2 []*mgmtProto.Checks
expectedBool bool
}{
{
name: "Equal Files In Equal Order Should Return True",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
expectedBool: true,
},
{
name: "Equal Files In Reverse Order Should Return True",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile2",
"testfile1",
},
},
},
expectedBool: true,
},
{
name: "Unequal Files Should Return False",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile3",
},
},
},
expectedBool: false,
},
{
name: "Compared With Empty Should Return False",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{},
},
},
expectedBool: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
})
}
}

func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
Expand Down
5 changes: 4 additions & 1 deletion management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -2319,7 +2319,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account

err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}

return nil
Expand All @@ -2335,6 +2335,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlock()

unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer unlockPeer()

account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
Expand Down
3 changes: 3 additions & 0 deletions management/server/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context

account.UpdatePeer(peer)

log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)

err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
if err != nil {
return false, fmt.Errorf("failed to save peer status: %w", err)
Expand Down Expand Up @@ -657,6 +659,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac

updated := peer.UpdateMetaIfNew(sync.Meta)
if updated {
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
Expand Down

0 comments on commit 4aee3c9

Please sign in to comment.