Skip to content

Commit

Permalink
refactor handlers to use GetAccountIDFromToken
Browse files Browse the repository at this point in the history
Signed-off-by: bcmmbaga <[email protected]>
  • Loading branch information
bcmmbaga committed Sep 22, 2024
1 parent 26dd045 commit 8f98add
Show file tree
Hide file tree
Showing 21 changed files with 465 additions and 362 deletions.
58 changes: 45 additions & 13 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,14 @@ type AccountManager interface {
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
DeleteAccount(ctx context.Context, accountID, userID string) error
MarkPATUsed(ctx context.Context, tokenID string) error
GetUserByID(ctx context.Context, id string) (*User, error)
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(ctx context.Context, accountID string) ([]*User, error)
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
Expand All @@ -107,7 +109,7 @@ type AccountManager interface {
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
Expand Down Expand Up @@ -145,6 +147,7 @@ type AccountManager interface {
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
}

type DefaultAccountManager struct {
Expand Down Expand Up @@ -1739,10 +1742,27 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st
return account, user, pat, nil
}

// GetAccountFromToken returns an account associated with this token.
func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
// GetAccountByID returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}

if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
}

unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

return am.Store.GetAccount(ctx, accountID)
}

// GetAccountIDFromToken returns an account ID associated with this token.
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
if claims.UserId == "" {
return nil, nil, fmt.Errorf("user ID is empty")
return "", "", fmt.Errorf("user ID is empty")
}
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
Expand All @@ -1754,28 +1774,27 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims

accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
if err != nil {
return nil, nil, err
return "", "", err
}

user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
if err != nil {
// this is not really possible because we got an account by user ID
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
}

if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil {
return nil, nil, err
return "", "", err
}
}

if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
return nil, nil, err
return "", "", err
}

// TODO: return account id, user id and error
return &Account{Id: accountID}, user, nil
return accountID, user.Id, nil
}

// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
Expand Down Expand Up @@ -2049,12 +2068,12 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
// group propagation and set the list of groups with access permissions.
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
account, _, err := am.GetAccountFromToken(ctx, claims)
accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
if err != nil {
return err
}

settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, account.Id)
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
Expand Down Expand Up @@ -2133,6 +2152,19 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor
return newLabel, nil
}

func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}

if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
}

return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
}

// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
Expand Down
53 changes: 20 additions & 33 deletions management/server/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,70 +27,48 @@ func (e *GroupLinkError) Error() string {

// GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

account, err := am.Store.GetAccount(ctx, accountID)
groups, err := am.GetAllGroups(ctx, accountID, userID)
if err != nil {
return nil, err
}

user, err := account.FindUser(userID)
if err != nil {
return nil, err
}

if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
}

group, ok := account.Groups[groupID]
if ok {
return group, nil
for _, group := range groups {
if group.ID == groupID {
return group, nil
}
}

return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}

// GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

account, err := am.Store.GetAccount(ctx, accountID)
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}

user, err := account.FindUser(userID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}

if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
if !user.HasAdminPower() && !user.IsServiceUser && settings.RegularUsersViewBlocked {
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
}

groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}

return groups, nil
return am.Store.GetAccountGroups(ctx, accountID)
}

// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

account, err := am.Store.GetAccount(ctx, accountID)
groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return nil, err
}

matchingGroups := make([]*nbgroup.Group, 0)
for _, group := range account.Groups {
for _, group := range groups {
if group.Name == groupName {
matchingGroups = append(matchingGroups, group)
}
Expand Down Expand Up @@ -262,6 +240,15 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
return nil
}

allGroup, err := account.GetGroupAll()
if err != nil {
return err
}

if allGroup.ID == groupID {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}

if err = validateDeleteGroup(account, group, userId); err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion management/server/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
}
claims := s.jwtClaimsExtractor.FromToken(token)
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
_, _, err = s.accountManager.GetAccountFromToken(ctx, claims)
_, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims)
if err != nil {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}
Expand Down
46 changes: 21 additions & 25 deletions management/server/http/accounts_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,26 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}

if !(user.HasAdminPower() || user.IsServiceUser) {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}

resp := toAccountResponse(account)
resp := toAccountResponse(accountID, settings)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}

// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
_, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
_, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
Expand Down Expand Up @@ -96,24 +97,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}

updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings)
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}

resp := toAccountResponse(updatedAccount)
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings)

util.WriteJSONObject(r.Context(), w, &resp)
}

// DeleteAccount is a HTTP DELETE handler to delete an account
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}

claims := h.claimsExtractor.FromRequestContext(r)
vars := mux.Vars(r)
targetAccountID := vars["accountId"]
Expand All @@ -131,28 +127,28 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, emptyObject{})
}

func toAccountResponse(account *server.Account) *api.Account {
jwtAllowGroups := account.Settings.JWTAllowGroups
func toAccountResponse(accountID string, settings *server.Settings) *api.Account {
jwtAllowGroups := settings.JWTAllowGroups
if jwtAllowGroups == nil {
jwtAllowGroups = []string{}
}

settings := api.AccountSettings{
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled,
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
apiSettings := api.AccountSettings{
PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()),
PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled,
GroupsPropagationEnabled: &settings.GroupsPropagationEnabled,
JwtGroupsEnabled: &settings.JWTGroupsEnabled,
JwtGroupsClaimName: &settings.JWTGroupsClaimName,
JwtAllowGroups: &jwtAllowGroups,
RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked,
RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
}

if account.Settings.Extra != nil {
settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled}
if settings.Extra != nil {
apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled}
}

return &api.Account{
Id: account.Id,
Settings: settings,
Id: accountID,
Settings: apiSettings,
}
}
8 changes: 4 additions & 4 deletions management/server/http/dns_settings_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetDNSSettings returns the DNS settings for the account
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}

dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
Expand All @@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
// UpdateDNSSettings handles update to DNS settings of an account
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
Expand All @@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: req.DisabledManagementGroups,
}

err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings)
err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
if err != nil {
util.WriteError(r.Context(), err, w)
return
Expand Down
Loading

0 comments on commit 8f98add

Please sign in to comment.