Skip to content

Commit

Permalink
[management] Add integration test for the setup-keys API endpoints (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
pascal-fischer authored Jan 2, 2025
1 parent 03fd656 commit 782e3f8
Show file tree
Hide file tree
Showing 23 changed files with 1,919 additions and 150 deletions.
4 changes: 2 additions & 2 deletions management/cmd/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import (
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
httpapi "github.com/netbirdio/netbird/management/server/http"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
Expand Down Expand Up @@ -281,7 +281,7 @@ var (
routersManager := routers.NewManager(store, permissionsManager, accountManager)
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)

httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ type DefaultAccountManager struct {
externalCacheManager ExternalCacheManager
ctx context.Context
eventStore activity.Store
geo *geolocation.Geolocation
geo geolocation.Geolocation

requestBuffer *AccountRequestBuffer

Expand Down Expand Up @@ -244,7 +244,7 @@ func BuildManager(
singleAccountModeDomain string,
dnsDomain string,
eventStore activity.Store,
geo *geolocation.Geolocation,
geo geolocation.Geolocation,
userDeleteFromIDPEnabled bool,
integratedPeerValidator integrated_validator.IntegratedValidator,
metrics telemetry.AppMetrics,
Expand Down
42 changes: 0 additions & 42 deletions management/server/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
Expand All @@ -38,47 +37,6 @@ import (
"github.com/netbirdio/netbird/route"
)

type MocIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
}

func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}

func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
if a.ValidatePeerFunc != nil {
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
}
return update, false, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{})
for _, peer := range peers {
validatedPeers[peer.ID] = struct{}{}
}
return validatedPeers, nil
}

func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer
}

func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil
}

func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil
}

func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {

}

func (MocIntegratedValidator) Stop(_ context.Context) {
}

func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) {
t.Helper()
peer := &nbpeer.Peer{
Expand Down
39 changes: 32 additions & 7 deletions management/server/geolocation/geolocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@ import (
log "github.com/sirupsen/logrus"
)

type Geolocation struct {
type Geolocation interface {
Lookup(ip net.IP) (*Record, error)
GetAllCountries() ([]Country, error)
GetCitiesByCountry(countryISOCode string) ([]City, error)
Stop() error
}

type geolocationImpl struct {
mmdbPath string
mux sync.RWMutex
db *maxminddb.Reader
Expand Down Expand Up @@ -54,7 +61,7 @@ const (
geonamesdbPattern = "geonames_*.db"
)

func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geolocation, error) {
func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (Geolocation, error) {
mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern)
mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate)
if err != nil {
Expand Down Expand Up @@ -86,7 +93,7 @@ func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geol
return nil, err
}

geo := &Geolocation{
geo := &geolocationImpl{
mmdbPath: mmdbPath,
mux: sync.RWMutex{},
db: db,
Expand All @@ -113,7 +120,7 @@ func openDB(mmdbPath string) (*maxminddb.Reader, error) {
return db, nil
}

func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) {
func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) {
gl.mux.RLock()
defer gl.mux.RUnlock()

Expand All @@ -127,7 +134,7 @@ func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) {
}

// GetAllCountries retrieves a list of all countries.
func (gl *Geolocation) GetAllCountries() ([]Country, error) {
func (gl *geolocationImpl) GetAllCountries() ([]Country, error) {
allCountries, err := gl.locationDB.GetAllCountries()
if err != nil {
return nil, err
Expand All @@ -143,7 +150,7 @@ func (gl *Geolocation) GetAllCountries() ([]Country, error) {
}

// GetCitiesByCountry retrieves a list of cities in a specific country based on the country's ISO code.
func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) {
func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, error) {
allCities, err := gl.locationDB.GetCitiesByCountry(countryISOCode)
if err != nil {
return nil, err
Expand All @@ -158,7 +165,7 @@ func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error)
return cities, nil
}

func (gl *Geolocation) Stop() error {
func (gl *geolocationImpl) Stop() error {
close(gl.stopCh)
if gl.db != nil {
if err := gl.db.Close(); err != nil {
Expand Down Expand Up @@ -259,3 +266,21 @@ func cleanupMaxMindDatabases(ctx context.Context, dataDir string, mmdbFile strin
}
return nil
}

type Mock struct{}

func (g *Mock) Lookup(ip net.IP) (*Record, error) {
return &Record{}, nil
}

func (g *Mock) GetAllCountries() ([]Country, error) {
return []Country{}, nil
}

func (g *Mock) GetCitiesByCountry(countryISOCode string) ([]City, error) {
return []City{}, nil
}

func (g *Mock) Stop() error {
return nil
}
2 changes: 1 addition & 1 deletion management/server/geolocation/geolocation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestGeoLite_Lookup(t *testing.T) {
db, err := openDB(filename)
assert.NoError(t, err)

geo := &Geolocation{
geo := &geolocationImpl{
mux: sync.RWMutex{},
db: db,
stopCh: make(chan struct{}),
Expand Down
4 changes: 2 additions & 2 deletions management/server/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type GRPCServer struct {
peersUpdateManager *PeersUpdateManager
config *Config
secretsManager SecretsManager
jwtValidator *jwtclaims.JWTValidator
jwtValidator jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
Expand All @@ -61,7 +61,7 @@ func NewServer(
return nil, err
}

var jwtValidator *jwtclaims.JWTValidator
var jwtValidator jwtclaims.JWTValidator

if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtValidator, err = jwtclaims.NewJWTValidator(
Expand Down
40 changes: 13 additions & 27 deletions management/server/http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,8 @@ import (

const apiPrefix = "/api"

type apiHandler struct {
Router *mux.Router
AccountManager s.AccountManager
geolocationManager *geolocation.Geolocation
AuthCfg configs.AuthCfg
}

// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
Expand Down Expand Up @@ -78,27 +71,20 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, networksMa
router := rootRouter.PathPrefix(prefix).Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)

api := apiHandler{
Router: router,
AccountManager: accountManager,
geolocationManager: LocationManager,
AuthCfg: authCfg,
}

if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil {
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}

accounts.AddEndpoints(api.AccountManager, authCfg, router)
peers.AddEndpoints(api.AccountManager, authCfg, router)
users.AddEndpoints(api.AccountManager, authCfg, router)
setup_keys.AddEndpoints(api.AccountManager, authCfg, router)
policies.AddEndpoints(api.AccountManager, api.geolocationManager, authCfg, router)
groups.AddEndpoints(api.AccountManager, authCfg, router)
routes.AddEndpoints(api.AccountManager, authCfg, router)
dns.AddEndpoints(api.AccountManager, authCfg, router)
events.AddEndpoints(api.AccountManager, authCfg, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, api.AccountManager, api.AccountManager.GetAccountIDFromToken, authCfg, router)
accounts.AddEndpoints(accountManager, authCfg, router)
peers.AddEndpoints(accountManager, authCfg, router)
users.AddEndpoints(accountManager, authCfg, router)
setup_keys.AddEndpoints(accountManager, authCfg, router)
policies.AddEndpoints(accountManager, LocationManager, authCfg, router)
groups.AddEndpoints(accountManager, authCfg, router)
routes.AddEndpoints(accountManager, authCfg, router)
dns.AddEndpoints(accountManager, authCfg, router)
events.AddEndpoints(accountManager, authCfg, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, accountManager.GetAccountIDFromToken, authCfg, router)

return rootRouter, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ var (
// geolocationsHandler is a handler that returns locations.
type geolocationsHandler struct {
accountManager server.AccountManager
geolocationManager *geolocation.Geolocation
geolocationManager geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor
}

func addLocationsEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg)
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
}

// newGeolocationsHandlerHandler creates a new Geolocations handler
func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler {
func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler {
return &geolocationsHandler{
accountManager: accountManager,
geolocationManager: geolocationManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type handler struct {
claimsExtractor *jwtclaims.ClaimsExtractor
}

func AddEndpoints(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
policiesHandler := newHandler(accountManager, authCfg)
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ import (
// postureChecksHandler is a handler that returns posture checks of the account.
type postureChecksHandler struct {
accountManager server.AccountManager
geolocationManager *geolocation.Geolocation
geolocationManager geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor
}

func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg)
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
Expand All @@ -34,7 +34,7 @@ func addPostureCheckEndpoint(accountManager server.AccountManager, locationManag
}

// newPostureChecksHandler creates a new PostureChecks handler
func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler {
func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler {
return &postureChecksHandler{
accountManager: accountManager,
geolocationManager: geolocationManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH
return claims.AccountId, claims.UserId, nil
},
},
geolocationManager: &geolocation.Geolocation{},
geolocationManager: &geolocation.Mock{},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
return
}

apiSetupKeys := toResponseBody(setupKey)
apiSetupKeys := ToResponseBody(setupKey)
// for the creation we need to send the plain key
apiSetupKeys.Key = setupKey.Key

Expand Down Expand Up @@ -183,7 +183,7 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {

apiSetupKeys := make([]*api.SetupKey, 0)
for _, key := range setupKeys {
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
apiSetupKeys = append(apiSetupKeys, ToResponseBody(key))
}

util.WriteJSONObject(r.Context(), w, apiSetupKeys)
Expand Down Expand Up @@ -216,14 +216,14 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
err := json.NewEncoder(w).Encode(toResponseBody(key))
err := json.NewEncoder(w).Encode(ToResponseBody(key))
if err != nil {
util.WriteError(ctx, err, w)
return
}
}

func toResponseBody(key *types.SetupKey) *api.SetupKey {
func ToResponseBody(key *types.SetupKey) *api.SetupKey {
var state string
switch {
case key.IsExpired():
Expand Down
Loading

0 comments on commit 782e3f8

Please sign in to comment.