Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[management] Remove redundant get account calls in GetAccountFromToken #2615

Merged
merged 29 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
258b30c
refactor access control middleware and user access by JWT groups
bcmmbaga Sep 16, 2024
3cf1b02
refactor jwt groups extractor
bcmmbaga Sep 17, 2024
e5d55d3
refactor handlers to get account when necessary
bcmmbaga Sep 17, 2024
ccab3b4
refactor getAccountFromToken
bcmmbaga Sep 18, 2024
720d36a
refactor getAccountWithAuthorizationClaims
bcmmbaga Sep 18, 2024
a4c4158
Merge branch 'main' into refactor-get-account-by-token
bcmmbaga Sep 18, 2024
021fc8f
fix merge
bcmmbaga Sep 18, 2024
f60a423
revert handles change
bcmmbaga Sep 18, 2024
8f9c54f
remove GetUserByID from account manager
bcmmbaga Sep 18, 2024
9631cb4
fix tests
bcmmbaga Sep 18, 2024
4d9bb7e
refactor getAccountWithAuthorizationClaims to return account id
bcmmbaga Sep 20, 2024
26dd045
Merge branch 'main' into refactor-get-account-by-token
bcmmbaga Sep 20, 2024
8f98add
refactor handlers to use GetAccountIDFromToken
bcmmbaga Sep 22, 2024
7601a17
fix tests
bcmmbaga Sep 22, 2024
d9f612d
remove locks
bcmmbaga Sep 23, 2024
2884038
refactor
bcmmbaga Sep 24, 2024
1ffe89d
add GetGroupByName from store
bcmmbaga Sep 24, 2024
7561706
add GetGroupByID from store and refactor
bcmmbaga Sep 24, 2024
eab8564
Refactor retrieval of policy and posture checks
bcmmbaga Sep 24, 2024
d14b855
Refactor user permissions and retrieves PAT
bcmmbaga Sep 24, 2024
16174f0
Refactor route, setupkey, nameserver and dns to get record(s) from store
bcmmbaga Sep 25, 2024
41b212f
Refactor store
bcmmbaga Sep 25, 2024
b815393
fix lint
bcmmbaga Sep 25, 2024
c384874
fix tests
bcmmbaga Sep 25, 2024
dc82c2d
fix add missing policy source posture checks
bcmmbaga Sep 26, 2024
871595d
Merge branch 'main' into refactor-get-account-by-token
bcmmbaga Sep 26, 2024
4575ae2
add store lock
bcmmbaga Sep 26, 2024
b1b2b0a
fix tests
bcmmbaga Sep 26, 2024
e90d9ce
add get account
bcmmbaga Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ import (
cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"

"github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
Expand All @@ -41,6 +36,10 @@ import (
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
)

const (
Expand Down Expand Up @@ -1255,30 +1254,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return nil
}

// GetAccountIDByUserOrAccountID looks for an account by user or accountID, if no account is provided and
// userID doesn't have an account associated with it, one account is created
// domain is used to create a new account if no account is found
// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
// If an accountID is provided, it checks if the account exists and returns it.
// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
// If the user doesn't have an account, it creates one using the provided domain.
// Returns the account ID or an error if none is found or created.
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
if accountID != "" {
_, _, err := am.Store.GetAccountDomainAndCategory(ctx, accountID)
exists, err := am.Store.AccountExists(ctx, accountID)
if err != nil {
return "", err
}
if !exists {
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
}
return accountID, nil
} else if userID != "" {
}

if userID != "" {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
if err != nil {
return "", status.Errorf(status.NotFound, "account not found using user id: %s", userID)
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
}

err = am.addAccountIDToIDPAppMeta(ctx, userID, account)
if err != nil {
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
return "", err
}

return account.Id, nil
}

return "", status.Errorf(status.NotFound, "no valid user or account Id provided")
return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
}

func isNil(i idp.Manager) bool {
Expand Down Expand Up @@ -1808,6 +1814,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return nil
}

// TODO: Remove GetAccount after refactoring account peer's update
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

Expand Down Expand Up @@ -1907,7 +1914,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
}

domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, claims.AccountId)
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
if err != nil {
return "", err
}
Expand All @@ -1923,7 +1930,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)

// We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, claims.Domain)
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
if err != nil {
// if NotFound we are good to continue, otherwise return error
e, ok := status.FromError(err)
Expand Down
19 changes: 14 additions & 5 deletions management/server/file_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@ import (
"sync"
"time"

"github.com/rs/xid"
log "github.com/sirupsen/logrus"

nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/util"
)
Expand Down Expand Up @@ -958,11 +957,11 @@ func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented")
}

func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ string) (string, error) {
func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) {
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
}

func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, accountID string) (string, string, error) {
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) {
s.mux.Lock()
defer s.mux.Unlock()

Expand All @@ -973,3 +972,13 @@ func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, accountID str

return account.Domain, account.DomainCategory, nil
}

// AccountExists checks whether an account exists by the given ID.
func (s *FileStore) AccountExists(_ context.Context, id string) (bool, error) {
_, exists := s.Accounts[id]
return exists, nil
}

func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Account) error {
return nil
}
65 changes: 54 additions & 11 deletions management/server/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
}

func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
accountID, err := s.GetAccountIDByPrivateDomain(ctx, domain)
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if err != nil {
return nil, err
}
Expand All @@ -409,11 +409,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
return s.GetAccount(ctx, accountID)
}

func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) {
var account Account

result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory)
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory,
).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
Expand All @@ -422,7 +423,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, domain strin
return "", status.Errorf(status.Internal, "issue getting account from store")
}

return account.Id, nil
return accountID, nil
}

func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
Expand Down Expand Up @@ -671,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
}

func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
var user User
var accountID string
result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
Expand Down Expand Up @@ -1035,10 +1035,53 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
}
}

// UpdateAccount updates an existing account's domain, DNS settings, and settings fields.
func (s *SqlStore) UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error {
updates := make(map[string]interface{})

if account.Domain != "" {
updates["domain"] = account.Domain
}

if account.DNSSettings.DisabledManagementGroups != nil {
updates["dns_settings"] = account.DNSSettings
}

if account.Settings != nil {
updates["settings"] = account.Settings
}

if len(updates) == 0 {
return nil
}

result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where("id = ?", account.Id).Updates(updates)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to update account: %v", result.Error)
}

if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "account not found")
}

return nil
}

// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) {
var count int64
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, id).Count(&count)
if result.Error != nil {
return false, result.Error
}
return count > 0, nil
}
pascal-fischer marked this conversation as resolved.
Show resolved Hide resolved

// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, accountID string) (string, string, error) {
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
var account Account
result := s.db.WithContext(ctx).Model(&Account{}).Select("domain", "domain_category").
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
Expand Down
57 changes: 34 additions & 23 deletions management/server/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ const (
type Store interface {
GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error)
GetAccountDomainAndCategory(ctx context.Context, accountID string) (string, string, error)
DeleteAccount(ctx context.Context, account *Account) error
AccountExists(ctx context.Context, id string) (bool, error)
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
Expand All @@ -49,45 +49,56 @@ type Store interface {
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error

GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error
SaveUsers(accountID string, users map[string]*User) error
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error

GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error

GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)

GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error

GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error

GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)

GetInstallationID() string
SaveInstallationID(ctx context.Context, ID string) error

// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock(ctx context.Context) func()
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error

// Close should close the store persisting all unsaved data.
Close(ctx context.Context) error
// GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
}

Expand Down