diff --git a/management/server/account.go b/management/server/account.go index 17d2f1486e6..8521be795ee 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -26,11 +26,12 @@ import ( ) const ( - PublicCategory = "public" - PrivateCategory = "private" - UnknownCategory = "unknown" - CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days - CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days + PublicCategory = "public" + PrivateCategory = "private" + UnknownCategory = "unknown" + CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days + CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days + DefaultPeerLoginExpiration = 24 * time.Hour ) func cacheEntryExpiration() time.Duration { @@ -48,6 +49,7 @@ type AccountManager interface { SaveUser(accountID, userID string, update *User) (*UserInfo, error) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) + GetAccountByPeerID(peerID string) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExists(accountId string) (*bool, error) @@ -93,6 +95,7 @@ type AccountManager interface { GetDNSSettings(accountID string, userID string) (*DNSSettings, error) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error GetPeer(accountID, peerID, userID string) (*Peer, error) + UpdatePeerLastLogin(peerID string) error } type DefaultAccountManager struct { @@ -134,6 +137,9 @@ type Account struct { Routes map[string]*route.Route NameServerGroups map[string]*nbdns.NameServerGroup DNSSettings *DNSSettings + // PeerLoginExpiration is a setting that indicates when peer login expires. + // Applies to all peers that have Peer.LoginExpirationEnabled set to true. + PeerLoginExpiration time.Duration } type UserInfo struct { @@ -484,6 +490,7 @@ func (a *Account) Copy() *Account { Routes: routes, NameServerGroups: nsGroups, DNSSettings: dnsSettings, + PeerLoginExpiration: a.PeerLoginExpiration, } } @@ -606,6 +613,11 @@ func (am *DefaultAccountManager) warmupIDPCache() error { return nil } +// GetAccountByPeerID returns account from the store by a provided peer ID +func (am *DefaultAccountManager) GetAccountByPeerID(peerID string) (*Account, error) { + return am.Store.GetAccountByPeerID(peerID) +} + // GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and // userID doesn't have an account associated with it, one account is created func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) { @@ -1085,16 +1097,17 @@ func newAccountWithId(accountId, userId, domain string) *Account { log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key) acc := &Account{ - Id: accountId, - SetupKeys: setupKeys, - Network: network, - Peers: peers, - Users: users, - CreatedBy: userId, - Domain: domain, - Routes: routes, - NameServerGroups: nameServersGroups, - DNSSettings: dnsSettings, + Id: accountId, + SetupKeys: setupKeys, + Network: network, + Peers: peers, + Users: users, + CreatedBy: userId, + Domain: domain, + Routes: routes, + NameServerGroups: nameServersGroups, + DNSSettings: dnsSettings, + PeerLoginExpiration: DefaultPeerLoginExpiration, } addAllGroup(acc) diff --git a/management/server/file_store.go b/management/server/file_store.go index fe71b75788d..602b2e7d867 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -81,6 +81,11 @@ func restore(file string) (*FileStore, error) { store.PeerID2AccountID = make(map[string]string) for accountID, account := range store.Accounts { + + if account.PeerLoginExpiration.Seconds() == 0 { + account.PeerLoginExpiration = DefaultPeerLoginExpiration + } + for setupKeyId := range account.SetupKeys { store.SetupKeyID2AccountID[strings.ToUpper(setupKeyId)] = accountID } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 785a2fdf0f1..039a396e9ce 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -132,6 +132,15 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi return msg } + account, err := s.accountManager.GetAccountByPeerID(peer.ID) + if err != nil { + return status.Error(codes.Internal, "internal server error") + } + expired, left := peer.LoginExpired(account.PeerLoginExpiration) + if peer.UserID != "" && expired { + return status.Errorf(codes.PermissionDenied, "peer login has expired %v ago. Please log in once more", left) + } + syncReq := &proto.SyncRequest{} err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq) if err != nil { @@ -196,29 +205,37 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi } } +func (s *GRPCServer) validateToken(jwtToken string) (string, error) { + if s.jwtMiddleware == nil { + return "", status.Error(codes.Internal, "no jwt middleware set") + } + + token, err := s.jwtMiddleware.ValidateAndParse(jwtToken) + if err != nil { + return "", status.Errorf(codes.Internal, "invalid jwt token, err: %v", err) + } + claims := s.jwtClaimsExtractor.FromToken(token) + // we need to call this method because if user is new, we will automatically add it to existing or create a new account + _, _, err = s.accountManager.GetAccountFromToken(claims) + if err != nil { + return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) + } + + return claims.UserId, nil +} + func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Peer, error) { var ( reqSetupKey string userID string + err error ) if req.GetJwtToken() != "" { log.Debugln("using jwt token to register peer") - - if s.jwtMiddleware == nil { - return nil, status.Error(codes.Internal, "no jwt middleware set") - } - - token, err := s.jwtMiddleware.ValidateAndParse(req.GetJwtToken()) - if err != nil { - return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err) - } - claims := s.jwtClaimsExtractor.FromToken(token) - userID = claims.UserId - // we need to call this method because if user is new, we will automatically add it to existing or create a new account - _, _, err = s.accountManager.GetAccountFromToken(claims) + userID, err = s.validateToken(req.JwtToken) if err != nil { - return nil, status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) + return nil, err } } else { log.Debugln("using setup key to register peer") @@ -354,6 +371,29 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p } } + // check if peer login has expired + account, err := s.accountManager.GetAccountByPeerID(peer.ID) + if err != nil { + return nil, status.Error(codes.Internal, "internal server error") + } + expired, left := peer.LoginExpired(account.PeerLoginExpiration) + if peer.UserID != "" && expired { + // it might be that peer expired but user has logged in already, check token then + if loginReq.GetJwtToken() == "" { + return nil, status.Errorf(codes.PermissionDenied, + "peer login has expired %v ago. Please log in once more", left) + } + _, err = s.validateToken(loginReq.GetJwtToken()) + if err != nil { + return nil, err + } + + err = s.accountManager.UpdatePeerLastLogin(peer.ID) + if err != nil { + return nil, err + } + } + var sshKey []byte if loginReq.GetPeerKeys() != nil { sshKey = loginReq.GetPeerKeys().GetSshPubKey() diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d5d68ef9ffd..00b29586976 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -67,6 +67,8 @@ type MockAccountManager struct { GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error) SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error GetPeerFunc func(accountID, peerID, userID string) (*server.Peer, error) + GetAccountByPeerIDFunc func(peerID string) (*server.Account, error) + UpdatePeerLastLoginFunc func(peerID string) error } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface @@ -526,3 +528,19 @@ func (am *MockAccountManager) GetPeer(accountID, peerID, userID string) (*server } return nil, status.Errorf(codes.Unimplemented, "method GetPeer is not implemented") } + +// GetAccountByPeerID mocks GetAccountByPeerID of the AccountManager interface +func (am *MockAccountManager) GetAccountByPeerID(peerID string) (*server.Account, error) { + if am.GetAccountByPeerIDFunc != nil { + return am.GetAccountByPeerIDFunc(peerID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountByPeerID is not implemented") +} + +// UpdatePeerLastLogin mocks UpdatePeerLastLogin of the AccountManager interface +func (am *MockAccountManager) UpdatePeerLastLogin(peerID string) error { + if am.UpdatePeerLastLoginFunc != nil { + return am.UpdatePeerLastLoginFunc(peerID) + } + return status.Errorf(codes.Unimplemented, "method UpdatePeerLastLogin is not implemented") +} diff --git a/management/server/peer.go b/management/server/peer.go index db260f999b5..79bbbebac0e 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -58,27 +58,44 @@ type Peer struct { UserID string // SSHKey is a public SSH key of the peer SSHKey string - // SSHEnabled indicated whether SSH server is enabled on the peer + // SSHEnabled indicates whether SSH server is enabled on the peer SSHEnabled bool + // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. + // Works with LastLogin + LoginExpirationEnabled bool + // LastLogin the time when peer performed last login operation + LastLogin time.Time } // Copy copies Peer object func (p *Peer) Copy() *Peer { return &Peer{ - ID: p.ID, - Key: p.Key, - SetupKey: p.SetupKey, - IP: p.IP, - Meta: p.Meta, - Name: p.Name, - Status: p.Status, - UserID: p.UserID, - SSHKey: p.SSHKey, - SSHEnabled: p.SSHEnabled, - DNSLabel: p.DNSLabel, + ID: p.ID, + Key: p.Key, + SetupKey: p.SetupKey, + IP: p.IP, + Meta: p.Meta, + Name: p.Name, + Status: p.Status, + UserID: p.UserID, + SSHKey: p.SSHKey, + SSHEnabled: p.SSHEnabled, + DNSLabel: p.DNSLabel, + LoginExpirationEnabled: p.LoginExpirationEnabled, + LastLogin: p.LastLogin, } } +// LoginExpired indicates whether peer's login has expired or not. +// If Peer.LastLogin plus the expiresIn duration has happened already then login has expired. +// Return true if login has expired, false otherwise and time left to expiration (negative when expired). +func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) { + expiresAt := p.LastLogin.Add(expiresIn) + now := time.Now() + left := expiresAt.Sub(now) + return p.LoginExpirationEnabled && (left <= 0), left +} + // FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain func (p *Peer) FQDN(dnsDomain string) string { if dnsDomain == "" { @@ -100,7 +117,7 @@ func (p *PeerStatus) Copy() *PeerStatus { } } -// GetPeer looks up peer by its public WireGuard key +// GetPeerByKey looks up peer by its public WireGuard key func (am *DefaultAccountManager) GetPeerByKey(peerPubKey string) (*Peer, error) { account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) @@ -436,17 +453,19 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (* } newPeer := &Peer{ - ID: xid.New().String(), - Key: peer.Key, - SetupKey: upperKey, - IP: nextIp, - Meta: peer.Meta, - Name: peer.Name, - DNSLabel: newLabel, - UserID: userID, - Status: &PeerStatus{Connected: false, LastSeen: time.Now()}, - SSHEnabled: false, - SSHKey: peer.SSHKey, + ID: xid.New().String(), + Key: peer.Key, + SetupKey: upperKey, + IP: nextIp, + Meta: peer.Meta, + Name: peer.Name, + DNSLabel: newLabel, + UserID: userID, + Status: &PeerStatus{Connected: false, LastSeen: time.Now()}, + SSHEnabled: false, + SSHKey: peer.SSHKey, + LastLogin: time.Now(), + LoginExpirationEnabled: false, } // add peer to 'All' group @@ -491,6 +510,38 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (* return newPeer, nil } +// UpdatePeerLastLogin sets Peer.LastLogin to the current timestamp. +func (am *DefaultAccountManager) UpdatePeerLastLogin(peerID string) error { + account, err := am.Store.GetAccountByPeerID(peerID) + if err != nil { + return err + } + + unlock := am.Store.AcquireAccountLock(account.Id) + defer unlock() + + // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) + account, err = am.Store.GetAccount(account.Id) + if err != nil { + return err + } + + peer := account.GetPeer(peerID) + if peer == nil { + return status.Errorf(status.NotFound, "peer with ID %s not found", peerID) + } + + peer.LastLogin = time.Now() + account.UpdatePeer(peer) + + err = am.Store.SaveAccount(account) + if err != nil { + return err + } + + return nil +} + // UpdatePeerSSHKey updates peer's public SSH key func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) error { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 6afc7c4a85f..e4998f2210e 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -3,11 +3,57 @@ package server import ( "github.com/stretchr/testify/assert" "testing" + "time" "github.com/rs/xid" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +func TestPeer_LoginExpired(t *testing.T) { + + tt := []struct { + name string + expirationEnbaled bool + lastLogin time.Time + expiresIn time.Duration + expected bool + }{ + { + name: "Peer Login Expiration Disabled. Peer Login Should Not Expire", + expirationEnbaled: false, + lastLogin: time.Now().Add(-25 * time.Hour), + expiresIn: time.Hour, + expected: false, + }, + { + name: "Peer Login Should Expire", + expirationEnbaled: true, + lastLogin: time.Now().Add(-25 * time.Hour), + expiresIn: time.Hour, + expected: true, + }, + { + name: "Peer Login Should Not Expire", + expirationEnbaled: true, + lastLogin: time.Now(), + expiresIn: time.Hour, + expected: false, + }, + } + + for _, c := range tt { + t.Run(c.name, func(t *testing.T) { + peer := &Peer{ + LoginExpirationEnabled: c.expirationEnbaled, + LastLogin: c.lastLogin, + } + + expired, _ := peer.LoginExpired(c.expiresIn) + assert.Equal(t, expired, c.expected) + }) + } +} + func TestAccountManager_GetNetworkMap(t *testing.T) { manager, err := createManager(t) if err != nil {