Skip to content

Commit

Permalink
Merge pull request netbirdio#745 from netbirdio/feature/pat_persistence
Browse files Browse the repository at this point in the history
PAT persistence
  • Loading branch information
pascal-fischer authored Mar 21, 2023
2 parents 4e17e72 + 47e9a23 commit 94e07bd
Show file tree
Hide file tree
Showing 10 changed files with 517 additions and 31 deletions.
54 changes: 51 additions & 3 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package server

import (
"context"
"crypto/sha256"
"fmt"
"hash/crc32"
"math/rand"
"net"
"net/netip"
Expand All @@ -12,17 +14,19 @@ import (
"sync"
"time"

"codeberg.org/ac/base62"
"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"

nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
)

const (
Expand Down Expand Up @@ -50,6 +54,7 @@ type AccountManager interface {
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
GetAccountFromPAT(pat string) (*Account, *User, error)
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error)
GetPeerByKey(peerKey string) (*Peer, error)
Expand All @@ -61,6 +66,8 @@ type AccountManager interface {
GetNetworkMap(peerID string) (*NetworkMap, error)
GetPeerNetwork(peerID string) (*Network, error)
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
AddPATToUser(accountID string, userID string, pat *PersonalAccessToken) error
DeletePAT(accountID string, userID string, tokenID string) error
UpdatePeerSSHKey(peerID string, sshKey string) error
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error)
Expand Down Expand Up @@ -1112,6 +1119,47 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
return nil
}

// GetAccountFromPAT returns Account and User associated with a personal access token
func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, error) {
if len(token) != PATLength {
return nil, nil, fmt.Errorf("token has wrong length")
}

prefix := token[:len(PATPrefix)]
if prefix != PATPrefix {
return nil, nil, fmt.Errorf("token has wrong prefix")
}
secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength]
encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength]

verificationChecksum, err := base62.Decode(encodedChecksum)
if err != nil {
return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
}

secretChecksum := crc32.ChecksumIEEE([]byte(secret))
if secretChecksum != verificationChecksum {
return nil, nil, fmt.Errorf("token checksum does not match")
}

hashedToken := sha256.Sum256([]byte(token))
tokenID, err := am.Store.GetTokenIDByHashedToken(string(hashedToken[:]))
if err != nil {
return nil, nil, err
}

user, err := am.Store.GetUserByTokenID(tokenID)
if err != nil {
return nil, nil, err
}

account, err := am.Store.GetAccountByUser(user.Id)
if err != nil {
return nil, nil, err
}
return account, user, nil
}

// GetAccountFromToken returns an account associated with this token
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
if claims.UserId == "" {
Expand Down
38 changes: 36 additions & 2 deletions management/server/account_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"crypto/sha256"
"fmt"
"net"
"reflect"
Expand Down Expand Up @@ -458,6 +459,39 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
}
}

func TestAccountManager_GetAccountFromPAT(t *testing.T) {
store := newStore(t)
account := newAccountWithId("account_id", "testuser", "")

token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
account.Users["someUser"] = &User{
Id: "someUser",
PATs: map[string]*PersonalAccessToken{
"pat1": {
ID: "tokenId",
HashedToken: string(hashedToken[:]),
},
},
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}

am := DefaultAccountManager{
Store: store,
}

account, user, err := am.GetAccountFromPAT(token)
if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err)
}

assert.Equal(t, "account_id", account.Id)
assert.Equal(t, "someUser", user.Id)
}

func TestAccountManager_PrivateAccount(t *testing.T) {
manager, err := createManager(t)
if err != nil {
Expand Down Expand Up @@ -1208,8 +1242,8 @@ func TestAccount_Copy(t *testing.T) {
Id: "user1",
Role: UserRoleAdmin,
AutoGroups: []string{"group1"},
PATs: []PersonalAccessToken{
{
PATs: map[string]*PersonalAccessToken{
"pat1": {
ID: "pat1",
Description: "First PAT",
HashedToken: "SoMeHaShEdToKeN",
Expand Down
100 changes: 84 additions & 16 deletions management/server/file_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ import (
"sync"
"time"

"github.com/netbirdio/netbird/management/server/status"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/management/server/status"

"github.com/netbirdio/netbird/util"
)

Expand All @@ -25,6 +26,8 @@ type FileStore struct {
PeerID2AccountID map[string]string `json:"-"`
UserID2AccountID map[string]string `json:"-"`
PrivateDomain2AccountID map[string]string `json:"-"`
HashedPAT2TokenID map[string]string `json:"-"`
TokenID2UserID map[string]string `json:"-"`
InstallationID string

// mutex to synchronise Store read/write operations
Expand Down Expand Up @@ -57,6 +60,8 @@ func restore(file string) (*FileStore, error) {
UserID2AccountID: make(map[string]string),
PrivateDomain2AccountID: make(map[string]string),
PeerID2AccountID: make(map[string]string),
HashedPAT2TokenID: make(map[string]string),
TokenID2UserID: make(map[string]string),
storeFile: file,
}

Expand All @@ -80,6 +85,8 @@ func restore(file string) (*FileStore, error) {
store.UserID2AccountID = make(map[string]string)
store.PrivateDomain2AccountID = make(map[string]string)
store.PeerID2AccountID = make(map[string]string)
store.HashedPAT2TokenID = make(map[string]string)
store.TokenID2UserID = make(map[string]string)

for accountID, account := range store.Accounts {
if account.Settings == nil {
Expand All @@ -103,9 +110,10 @@ func restore(file string) (*FileStore, error) {
}
for _, user := range account.Users {
store.UserID2AccountID[user.Id] = accountID
}
for _, user := range account.Users {
store.UserID2AccountID[user.Id] = accountID
for _, pat := range user.PATs {
store.TokenID2UserID[pat.ID] = user.Id
store.HashedPAT2TokenID[pat.HashedToken[:]] = pat.ID
}
}

if account.Domain != "" && account.DomainCategory == PrivateCategory &&
Expand Down Expand Up @@ -258,6 +266,10 @@ func (s *FileStore) SaveAccount(account *Account) error {

for _, user := range accountCopy.Users {
s.UserID2AccountID[user.Id] = accountCopy.Id
for _, pat := range user.PATs {
s.TokenID2UserID[pat.ID] = user.Id
s.HashedPAT2TokenID[pat.HashedToken[:]] = pat.ID
}
}

if accountCopy.DomainCategory == PrivateCategory && accountCopy.IsDomainPrimaryAccount {
Expand All @@ -276,13 +288,33 @@ func (s *FileStore) SaveAccount(account *Account) error {
return s.persist(s.storeFile)
}

// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID
func (s *FileStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
s.mux.Lock()
defer s.mux.Unlock()

delete(s.HashedPAT2TokenID, hashedToken)

return s.persist(s.storeFile)
}

// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID
func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error {
s.mux.Lock()
defer s.mux.Unlock()

delete(s.TokenID2UserID, tokenID)

return s.persist(s.storeFile)
}

// GetAccountByPrivateDomain returns account by private domain
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()

accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)]
if !accountIDFound {
accountID, ok := s.PrivateDomain2AccountID[strings.ToLower(domain)]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}

Expand All @@ -299,8 +331,8 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()

accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !accountIDFound {
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
}

Expand All @@ -312,6 +344,42 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
return account.Copy(), nil
}

// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret
func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()

tokenID, ok := s.HashedPAT2TokenID[token]
if !ok {
return "", status.Errorf(status.NotFound, "tokenID not found: provided token doesn't exists")
}

return tokenID, nil
}

// GetUserByTokenID returns a User object a tokenID belongs to
func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) {
s.mux.Lock()
defer s.mux.Unlock()

userID, ok := s.TokenID2UserID[tokenID]
if !ok {
return nil, status.Errorf(status.NotFound, "user not found: provided tokenID doesn't exists")
}

accountID, ok := s.UserID2AccountID[userID]
if !ok {
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
}

account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}

return account.Users[userID].Copy(), nil
}

// GetAllAccounts returns all accounts
func (s *FileStore) GetAllAccounts() (all []*Account) {
s.mux.Lock()
Expand All @@ -325,8 +393,8 @@ func (s *FileStore) GetAllAccounts() (all []*Account) {

// getAccount returns a reference to the Account. Should not return a copy.
func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, accountFound := s.Accounts[accountID]
if !accountFound {
account, ok := s.Accounts[accountID]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found")
}

Expand All @@ -351,8 +419,8 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()

accountID, accountIDFound := s.UserID2AccountID[userID]
if !accountIDFound {
accountID, ok := s.UserID2AccountID[userID]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found")
}

Expand All @@ -369,8 +437,8 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()

accountID, accountIDFound := s.PeerID2AccountID[peerID]
if !accountIDFound {
accountID, ok := s.PeerID2AccountID[peerID]
if !ok {
return nil, status.Errorf(status.NotFound, "provided peer ID doesn't exists %s", peerID)
}

Expand All @@ -395,8 +463,8 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()

accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey]
if !accountIDFound {
accountID, ok := s.PeerKeyID2AccountID[peerKey]
if !ok {
return nil, status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
}

Expand Down
Loading

0 comments on commit 94e07bd

Please sign in to comment.