Skip to content

Commit

Permalink
Add FindExistingPostureCheck (#2075)
Browse files Browse the repository at this point in the history
  • Loading branch information
pascal-fischer authored May 30, 2024
1 parent f176807 commit 012235f
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 141 deletions.
5 changes: 5 additions & 0 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ type AccountManager interface {
GetValidatedPeers(account *Account) (map[string]struct{}, error)
SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
CancelPeerRoutines(peer *nbpeer.Peer) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
}

type DefaultAccountManager struct {
Expand Down Expand Up @@ -1961,6 +1962,10 @@ func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
am.updateAccountPeers(updatedAccount)
}

func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
return am.Store.GetPostureCheckByChecksDefinition(accountID, checks)
}

// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
Expand Down
5 changes: 5 additions & 0 deletions management/server/file_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"

Expand Down Expand Up @@ -667,6 +668,10 @@ func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.T
return nil
}

func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
}

// Close the FileStore persisting data to disk
func (s *FileStore) Close() error {
s.mux.Lock()
Expand Down
149 changes: 8 additions & 141 deletions management/server/http/posture_checks_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ package http
import (
"encoding/json"
"net/http"
"net/netip"
"regexp"
"slices"

"github.com/gorilla/mux"
"github.com/rs/xid"

"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
Expand Down Expand Up @@ -59,7 +57,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt

postureChecks := []*api.PostureCheck{}
for _, postureCheck := range accountPostureChecks {
postureChecks = append(postureChecks, toPostureChecksResponse(postureCheck))
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
}

util.WriteJSONObject(w, postureChecks)
Expand Down Expand Up @@ -130,7 +128,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
return
}

util.WriteJSONObject(w, toPostureChecksResponse(postureChecks))
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
}

// DeletePostureCheck handles posture check deletion request
Expand Down Expand Up @@ -178,55 +176,26 @@ func (p *PostureChecksHandler) savePostureChecks(
return
}

if postureChecksID == "" {
postureChecksID = xid.New().String()
}

postureChecks := posture.Checks{
ID: postureChecksID,
Name: req.Name,
Description: req.Description,
}

if nbVersionCheck := req.Checks.NbVersionCheck; nbVersionCheck != nil {
postureChecks.Checks.NBVersionCheck = &posture.NBVersionCheck{
MinVersion: nbVersionCheck.MinVersion,
}
}

if osVersionCheck := req.Checks.OsVersionCheck; osVersionCheck != nil {
postureChecks.Checks.OSVersionCheck = &posture.OSVersionCheck{
Android: (*posture.MinVersionCheck)(osVersionCheck.Android),
Darwin: (*posture.MinVersionCheck)(osVersionCheck.Darwin),
Ios: (*posture.MinVersionCheck)(osVersionCheck.Ios),
Linux: (*posture.MinKernelVersionCheck)(osVersionCheck.Linux),
Windows: (*posture.MinKernelVersionCheck)(osVersionCheck.Windows),
}
}

if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if p.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
return
}
postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck)
}

if peerNetworkRangeCheck := req.Checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil {
postureChecks.Checks.PeerNetworkRangeCheck, err = toPeerNetworkRangeCheck(peerNetworkRangeCheck)
if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid network prefix"), w)
return
}
postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID)
if err != nil {
util.WriteError(err, w)
return
}

if err := p.accountManager.SavePostureChecks(account.Id, user.Id, &postureChecks); err != nil {
if err := p.accountManager.SavePostureChecks(account.Id, user.Id, postureChecks); err != nil {
util.WriteError(err, w)
return
}

util.WriteJSONObject(w, toPostureChecksResponse(&postureChecks))
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
}

func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
Expand Down Expand Up @@ -294,105 +263,3 @@ func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {

return nil
}

func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck {
var checks api.Checks

if postureChecks.Checks.NBVersionCheck != nil {
checks.NbVersionCheck = &api.NBVersionCheck{
MinVersion: postureChecks.Checks.NBVersionCheck.MinVersion,
}
}

if postureChecks.Checks.OSVersionCheck != nil {
checks.OsVersionCheck = &api.OSVersionCheck{
Android: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Android),
Darwin: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Darwin),
Ios: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Ios),
Linux: (*api.MinKernelVersionCheck)(postureChecks.Checks.OSVersionCheck.Linux),
Windows: (*api.MinKernelVersionCheck)(postureChecks.Checks.OSVersionCheck.Windows),
}
}

if postureChecks.Checks.GeoLocationCheck != nil {
checks.GeoLocationCheck = toGeoLocationCheckResponse(postureChecks.Checks.GeoLocationCheck)
}

if postureChecks.Checks.PeerNetworkRangeCheck != nil {
checks.PeerNetworkRangeCheck = toPeerNetworkRangeCheckResponse(postureChecks.Checks.PeerNetworkRangeCheck)
}

return &api.PostureCheck{
Id: postureChecks.ID,
Name: postureChecks.Name,
Description: &postureChecks.Description,
Checks: checks,
}
}

func toGeoLocationCheckResponse(geoLocationCheck *posture.GeoLocationCheck) *api.GeoLocationCheck {
locations := make([]api.Location, 0, len(geoLocationCheck.Locations))
for _, loc := range geoLocationCheck.Locations {
l := loc // make G601 happy
var cityName *string
if loc.CityName != "" {
cityName = &l.CityName
}
locations = append(locations, api.Location{
CityName: cityName,
CountryCode: loc.CountryCode,
})
}

return &api.GeoLocationCheck{
Action: api.GeoLocationCheckAction(geoLocationCheck.Action),
Locations: locations,
}
}

func toPostureGeoLocationCheck(apiGeoLocationCheck *api.GeoLocationCheck) *posture.GeoLocationCheck {
locations := make([]posture.Location, 0, len(apiGeoLocationCheck.Locations))
for _, loc := range apiGeoLocationCheck.Locations {
cityName := ""
if loc.CityName != nil {
cityName = *loc.CityName
}
locations = append(locations, posture.Location{
CountryCode: loc.CountryCode,
CityName: cityName,
})
}

return &posture.GeoLocationCheck{
Action: string(apiGeoLocationCheck.Action),
Locations: locations,
}
}

func toPeerNetworkRangeCheckResponse(check *posture.PeerNetworkRangeCheck) *api.PeerNetworkRangeCheck {
netPrefixes := make([]string, 0, len(check.Ranges))
for _, netPrefix := range check.Ranges {
netPrefixes = append(netPrefixes, netPrefix.String())
}

return &api.PeerNetworkRangeCheck{
Ranges: netPrefixes,
Action: api.PeerNetworkRangeCheckAction(check.Action),
}
}

func toPeerNetworkRangeCheck(check *api.PeerNetworkRangeCheck) (*posture.PeerNetworkRangeCheck, error) {
prefixes := make([]netip.Prefix, 0)
for _, prefix := range check.Ranges {
parsedPrefix, err := netip.ParsePrefix(prefix)
if err != nil {
return nil, err
}
prefixes = append(prefixes, parsedPrefix)
}

return &posture.PeerNetworkRangeCheck{
Ranges: prefixes,
Action: string(check.Action),
}, nil
}
9 changes: 9 additions & 0 deletions management/server/mock_server/account_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ type MockAccountManager struct {
GetIdpManagerFunc func() idp.Manager
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
GroupValidationFunc func(accountId string, groups []string) (bool, error)
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
}

func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) {
Expand Down Expand Up @@ -734,3 +735,11 @@ func (am *MockAccountManager) GroupValidation(accountId string, groups []string)
}
return false, status.Errorf(codes.Unimplemented, "method GroupValidation is not implemented")
}

// FindExistingPostureCheck mocks FindExistingPostureCheck of the AccountManager interface
func (am *MockAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
if am.FindExistingPostureCheckFunc != nil {
return am.FindExistingPostureCheckFunc(accountID, checks)
}
return nil, status.Errorf(codes.Unimplemented, "method FindExistingPostureCheck is not implemented")
}
Loading

0 comments on commit 012235f

Please sign in to comment.