Skip to content

Commit

Permalink
[management] Remove redundant get account calls in GetAccountFromToken (
Browse files Browse the repository at this point in the history
#2615)

* refactor access control middleware and user access by JWT groups

Signed-off-by: bcmmbaga <[email protected]>

* refactor jwt groups extractor

Signed-off-by: bcmmbaga <[email protected]>

* refactor handlers to get account when necessary

Signed-off-by: bcmmbaga <[email protected]>

* refactor getAccountFromToken

Signed-off-by: bcmmbaga <[email protected]>

* refactor getAccountWithAuthorizationClaims

Signed-off-by: bcmmbaga <[email protected]>

* fix merge

Signed-off-by: bcmmbaga <[email protected]>

* revert handles change

Signed-off-by: bcmmbaga <[email protected]>

* remove GetUserByID from account manager

Signed-off-by: bcmmbaga <[email protected]>

* fix tests

Signed-off-by: bcmmbaga <[email protected]>

* refactor getAccountWithAuthorizationClaims to return account id

Signed-off-by: bcmmbaga <[email protected]>

* refactor handlers to use GetAccountIDFromToken

Signed-off-by: bcmmbaga <[email protected]>

* fix tests

Signed-off-by: bcmmbaga <[email protected]>

* remove locks

Signed-off-by: bcmmbaga <[email protected]>

* refactor

Signed-off-by: bcmmbaga <[email protected]>

* add GetGroupByName from store

Signed-off-by: bcmmbaga <[email protected]>

* add GetGroupByID from store and refactor

Signed-off-by: bcmmbaga <[email protected]>

* Refactor retrieval of policy and posture checks

Signed-off-by: bcmmbaga <[email protected]>

* Refactor user permissions and retrieves PAT

Signed-off-by: bcmmbaga <[email protected]>

* Refactor route, setupkey, nameserver and dns to get record(s) from store

Signed-off-by: bcmmbaga <[email protected]>

* Refactor store

Signed-off-by: bcmmbaga <[email protected]>

* fix lint

Signed-off-by: bcmmbaga <[email protected]>

* fix tests

Signed-off-by: bcmmbaga <[email protected]>

* fix add missing policy source posture checks

Signed-off-by: bcmmbaga <[email protected]>

* add store lock

Signed-off-by: bcmmbaga <[email protected]>

* fix tests

Signed-off-by: bcmmbaga <[email protected]>

* add get account

Signed-off-by: bcmmbaga <[email protected]>

---------

Signed-off-by: bcmmbaga <[email protected]>
  • Loading branch information
bcmmbaga authored Sep 27, 2024
1 parent 4ebf6e1 commit acb73bd
Show file tree
Hide file tree
Showing 44 changed files with 1,247 additions and 949 deletions.
383 changes: 235 additions & 148 deletions management/server/account.go

Large diffs are not rendered by default.

107 changes: 73 additions & 34 deletions management/server/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
assert.Equal(t, account.Id, ev.TargetID)
}

func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
type initUserParams jwtclaims.AuthorizationClaims

type test struct {
Expand Down Expand Up @@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed")

initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")

if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed")
Expand All @@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id
}

account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims)
require.NoError(t, err, "support function failed")

account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")

verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)

Expand All @@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "unable to create account manager")

accountID := initAccount.Id
acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
// that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount = acc
initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")

claims := jwtclaims.AuthorizationClaims{
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
Expand All @@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
}

t.Run("JWT groups disabled", func(t *testing.T) {
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")

account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")

require.Len(t, account.Groups, 1, "only ALL group should exists")
})

Expand All @@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")

account, _, err := manager.GetAccountFromToken(context.Background(), claims)
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")

account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")

require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
})

Expand All @@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")

account, _, err := manager.GetAccountFromToken(context.Background(), claims)
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")

account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")

require.Len(t, account.Groups, 3, "groups should be added to the account")

groupsByNames := map[string]*group.Group{}
Expand Down Expand Up @@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {

userId := "test_user"

account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
if err != nil {
t.Fatal(err)
}
if account == nil {
if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userId)
return
}

_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
}

_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
if err == nil {
t.Errorf("expected an error when user and account IDs are empty")
}
Expand Down Expand Up @@ -1240,7 +1260,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()

if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
Expand Down Expand Up @@ -1648,19 +1668,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")

assert.NotNil(t, account.Settings)
assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true)
assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")

assert.NotNil(t, settings)
assert.Equal(t, settings.PeerLoginExpirationEnabled, true)
assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour)
}

func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")

key, err := wgtypes.GenerateKey()
Expand All @@ -1672,11 +1695,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
})
require.NoError(t, err, "unable to add peer")

account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")

account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")

err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{

account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
Expand Down Expand Up @@ -1713,7 +1741,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")

key, err := wgtypes.GenerateKey()
Expand All @@ -1724,7 +1752,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
Expand All @@ -1741,8 +1769,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
},
}

account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")

account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")

// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
Expand All @@ -1757,7 +1789,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")

key, err := wgtypes.GenerateKey()
Expand All @@ -1769,8 +1801,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
})
require.NoError(t, err, "unable to add peer")

account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")

account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")

err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")

Expand Down Expand Up @@ -1813,30 +1849,33 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")

updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)

account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "unable to get account by ID")

assert.False(t, account.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")

_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
assert.False(t, settings.PeerLoginExpirationEnabled)
assert.Equal(t, settings.PeerLoginExpiration, time.Hour)

_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false,
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")

_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false,
})
Expand Down
16 changes: 4 additions & 12 deletions management/server/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,24 +80,16 @@ func (d DNSSettings) Copy() DNSSettings {

// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}

user, err := account.FindUser(userID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}

if !(user.HasAdminPower() || user.IsServiceUser) {
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
}
dnsSettings := account.DNSSettings.Copy()
return &dnsSettings, nil

return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
}

// SaveDNSSettings validates a user role and updates the account's DNS settings
Expand Down
Loading

0 comments on commit acb73bd

Please sign in to comment.