diff --git a/management/cmd/management.go b/management/cmd/management.go index 4f34009b7e1..1c8fca8dceb 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -42,7 +42,7 @@ import ( nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" - httpapi "github.com/netbirdio/netbird/management/server/http" + nbhttp "github.com/netbirdio/netbird/management/server/http" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -281,7 +281,7 @@ var ( routersManager := routers.NewManager(store, permissionsManager, accountManager) networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager) - httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) + httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index 83a8759f9b1..6c8205f26d8 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -161,7 +161,7 @@ type DefaultAccountManager struct { externalCacheManager ExternalCacheManager ctx context.Context eventStore activity.Store - geo *geolocation.Geolocation + geo geolocation.Geolocation requestBuffer *AccountRequestBuffer @@ -244,7 +244,7 @@ func BuildManager( singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, - geo *geolocation.Geolocation, + geo geolocation.Geolocation, userDeleteFromIDPEnabled bool, integratedPeerValidator integrated_validator.IntegratedValidator, metrics telemetry.AppMetrics, diff --git a/management/server/account_test.go b/management/server/account_test.go index 2289c96f90b..4f6cdf78dba 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -27,7 +27,6 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -38,47 +37,6 @@ import ( "github.com/netbirdio/netbird/route" ) -type MocIntegratedValidator struct { - ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) -} - -func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { - return nil -} - -func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { - if a.ValidatePeerFunc != nil { - return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) - } - return update, false, nil -} -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { - validatedPeers := make(map[string]struct{}) - for _, peer := range peers { - validatedPeers[peer.ID] = struct{}{} - } - return validatedPeers, nil -} - -func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { - return peer -} - -func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { - return false, false, nil -} - -func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { - return nil -} - -func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { - -} - -func (MocIntegratedValidator) Stop(_ context.Context) { -} - func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) { t.Helper() peer := &nbpeer.Peer{ diff --git a/management/server/geolocation/geolocation.go b/management/server/geolocation/geolocation.go index 553a3158187..c0179a1c481 100644 --- a/management/server/geolocation/geolocation.go +++ b/management/server/geolocation/geolocation.go @@ -14,7 +14,14 @@ import ( log "github.com/sirupsen/logrus" ) -type Geolocation struct { +type Geolocation interface { + Lookup(ip net.IP) (*Record, error) + GetAllCountries() ([]Country, error) + GetCitiesByCountry(countryISOCode string) ([]City, error) + Stop() error +} + +type geolocationImpl struct { mmdbPath string mux sync.RWMutex db *maxminddb.Reader @@ -54,7 +61,7 @@ const ( geonamesdbPattern = "geonames_*.db" ) -func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geolocation, error) { +func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (Geolocation, error) { mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern) mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate) if err != nil { @@ -86,7 +93,7 @@ func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geol return nil, err } - geo := &Geolocation{ + geo := &geolocationImpl{ mmdbPath: mmdbPath, mux: sync.RWMutex{}, db: db, @@ -113,7 +120,7 @@ func openDB(mmdbPath string) (*maxminddb.Reader, error) { return db, nil } -func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) { +func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) { gl.mux.RLock() defer gl.mux.RUnlock() @@ -127,7 +134,7 @@ func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) { } // GetAllCountries retrieves a list of all countries. -func (gl *Geolocation) GetAllCountries() ([]Country, error) { +func (gl *geolocationImpl) GetAllCountries() ([]Country, error) { allCountries, err := gl.locationDB.GetAllCountries() if err != nil { return nil, err @@ -143,7 +150,7 @@ func (gl *Geolocation) GetAllCountries() ([]Country, error) { } // GetCitiesByCountry retrieves a list of cities in a specific country based on the country's ISO code. -func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) { +func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, error) { allCities, err := gl.locationDB.GetCitiesByCountry(countryISOCode) if err != nil { return nil, err @@ -158,7 +165,7 @@ func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) return cities, nil } -func (gl *Geolocation) Stop() error { +func (gl *geolocationImpl) Stop() error { close(gl.stopCh) if gl.db != nil { if err := gl.db.Close(); err != nil { @@ -259,3 +266,21 @@ func cleanupMaxMindDatabases(ctx context.Context, dataDir string, mmdbFile strin } return nil } + +type Mock struct{} + +func (g *Mock) Lookup(ip net.IP) (*Record, error) { + return &Record{}, nil +} + +func (g *Mock) GetAllCountries() ([]Country, error) { + return []Country{}, nil +} + +func (g *Mock) GetCitiesByCountry(countryISOCode string) ([]City, error) { + return []City{}, nil +} + +func (g *Mock) Stop() error { + return nil +} diff --git a/management/server/geolocation/geolocation_test.go b/management/server/geolocation/geolocation_test.go index 9bdefd268ac..fecd715be07 100644 --- a/management/server/geolocation/geolocation_test.go +++ b/management/server/geolocation/geolocation_test.go @@ -24,7 +24,7 @@ func TestGeoLite_Lookup(t *testing.T) { db, err := openDB(filename) assert.NoError(t, err) - geo := &Geolocation{ + geo := &geolocationImpl{ mux: sync.RWMutex{}, db: db, stopCh: make(chan struct{}), diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 2635ac11b0a..daa23d2abfe 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -38,7 +38,7 @@ type GRPCServer struct { peersUpdateManager *PeersUpdateManager config *Config secretsManager SecretsManager - jwtValidator *jwtclaims.JWTValidator + jwtValidator jwtclaims.JWTValidator jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics ephemeralManager *EphemeralManager @@ -61,7 +61,7 @@ func NewServer( return nil, err } - var jwtValidator *jwtclaims.JWTValidator + var jwtValidator jwtclaims.JWTValidator if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { jwtValidator, err = jwtclaims.NewJWTValidator( diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 7db7ab5b842..cc2ad00b73d 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -35,15 +35,8 @@ import ( const apiPrefix = "/api" -type apiHandler struct { - Router *mux.Router - AccountManager s.AccountManager - geolocationManager *geolocation.Geolocation - AuthCfg configs.AuthCfg -} - -// APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { +// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. +func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), @@ -78,27 +71,20 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, networksMa router := rootRouter.PathPrefix(prefix).Subrouter() router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler) - api := apiHandler{ - Router: router, - AccountManager: accountManager, - geolocationManager: LocationManager, - AuthCfg: authCfg, - } - - if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil { + if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil { return nil, fmt.Errorf("register integrations endpoints: %w", err) } - accounts.AddEndpoints(api.AccountManager, authCfg, router) - peers.AddEndpoints(api.AccountManager, authCfg, router) - users.AddEndpoints(api.AccountManager, authCfg, router) - setup_keys.AddEndpoints(api.AccountManager, authCfg, router) - policies.AddEndpoints(api.AccountManager, api.geolocationManager, authCfg, router) - groups.AddEndpoints(api.AccountManager, authCfg, router) - routes.AddEndpoints(api.AccountManager, authCfg, router) - dns.AddEndpoints(api.AccountManager, authCfg, router) - events.AddEndpoints(api.AccountManager, authCfg, router) - networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, api.AccountManager, api.AccountManager.GetAccountIDFromToken, authCfg, router) + accounts.AddEndpoints(accountManager, authCfg, router) + peers.AddEndpoints(accountManager, authCfg, router) + users.AddEndpoints(accountManager, authCfg, router) + setup_keys.AddEndpoints(accountManager, authCfg, router) + policies.AddEndpoints(accountManager, LocationManager, authCfg, router) + groups.AddEndpoints(accountManager, authCfg, router) + routes.AddEndpoints(accountManager, authCfg, router) + dns.AddEndpoints(accountManager, authCfg, router) + events.AddEndpoints(accountManager, authCfg, router) + networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, accountManager.GetAccountIDFromToken, authCfg, router) return rootRouter, nil } diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go index e5bf3e6952d..161d974022a 100644 --- a/management/server/http/handlers/policies/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -22,18 +22,18 @@ var ( // geolocationsHandler is a handler that returns locations. type geolocationsHandler struct { accountManager server.AccountManager - geolocationManager *geolocation.Geolocation + geolocationManager geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -func addLocationsEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { +func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg) router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS") router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS") } // newGeolocationsHandlerHandler creates a new Geolocations handler -func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler { +func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler { return &geolocationsHandler{ accountManager: accountManager, geolocationManager: geolocationManager, diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index b1035c5701f..a748e73b8ed 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -23,7 +23,7 @@ type handler struct { claimsExtractor *jwtclaims.ClaimsExtractor } -func AddEndpoints(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { +func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { policiesHandler := newHandler(accountManager, authCfg) router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS") router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS") diff --git a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go index 44917605ba2..ce0d4878c92 100644 --- a/management/server/http/handlers/policies/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -19,11 +19,11 @@ import ( // postureChecksHandler is a handler that returns posture checks of the account. type postureChecksHandler struct { accountManager server.AccountManager - geolocationManager *geolocation.Geolocation + geolocationManager geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { +func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg) router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS") router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS") @@ -34,7 +34,7 @@ func addPostureCheckEndpoint(accountManager server.AccountManager, locationManag } // newPostureChecksHandler creates a new PostureChecks handler -func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler { +func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler { return &postureChecksHandler{ accountManager: accountManager, geolocationManager: geolocationManager, diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index e9a539e450a..237687fd4a4 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -70,7 +70,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH return claims.AccountId, claims.UserId, nil }, }, - geolocationManager: &geolocation.Geolocation{}, + geolocationManager: &geolocation.Mock{}, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go index 89696a16563..a627d72033f 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -93,7 +93,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { return } - apiSetupKeys := toResponseBody(setupKey) + apiSetupKeys := ToResponseBody(setupKey) // for the creation we need to send the plain key apiSetupKeys.Key = setupKey.Key @@ -183,7 +183,7 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { apiSetupKeys := make([]*api.SetupKey, 0) for _, key := range setupKeys { - apiSetupKeys = append(apiSetupKeys, toResponseBody(key)) + apiSetupKeys = append(apiSetupKeys, ToResponseBody(key)) } util.WriteJSONObject(r.Context(), w, apiSetupKeys) @@ -216,14 +216,14 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) - err := json.NewEncoder(w).Encode(toResponseBody(key)) + err := json.NewEncoder(w).Encode(ToResponseBody(key)) if err != nil { util.WriteError(ctx, err, w) return } } -func toResponseBody(key *types.SetupKey) *api.SetupKey { +func ToResponseBody(key *types.SetupKey) *api.SetupKey { var state string switch { case key.IsExpired(): diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 4ecb1e9ed4c..f56227c10dc 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -26,7 +26,6 @@ const ( newSetupKeyName = "New Setup Key" updatedSetupKeyName = "KKKey" notFoundSetupKeyID = "notFoundSetupKeyID" - testAccountID = "test_id" ) func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey, @@ -81,7 +80,7 @@ func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKe return jwtclaims.AuthorizationClaims{ UserId: user.Id, Domain: "hotmail.com", - AccountId: testAccountID, + AccountId: "testAccountId", } }), ), @@ -102,7 +101,7 @@ func TestSetupKeysHandlers(t *testing.T) { updatedDefaultSetupKey.Name = updatedSetupKeyName updatedDefaultSetupKey.Revoked = true - expectedNewKey := toResponseBody(newSetupKey) + expectedNewKey := ToResponseBody(newSetupKey) expectedNewKey.Key = plainKey tt := []struct { name string @@ -120,7 +119,7 @@ func TestSetupKeysHandlers(t *testing.T) { requestPath: "/api/setup-keys", expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKeys: []*api.SetupKey{toResponseBody(defaultSetupKey)}, + expectedSetupKeys: []*api.SetupKey{ToResponseBody(defaultSetupKey)}, }, { name: "Get Existing Setup Key", @@ -128,7 +127,7 @@ func TestSetupKeysHandlers(t *testing.T) { requestPath: "/api/setup-keys/" + existingSetupKeyID, expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKey: toResponseBody(defaultSetupKey), + expectedSetupKey: ToResponseBody(defaultSetupKey), }, { name: "Get Not Existing Setup Key", @@ -159,7 +158,7 @@ func TestSetupKeysHandlers(t *testing.T) { ))), expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKey: toResponseBody(updatedDefaultSetupKey), + expectedSetupKey: ToResponseBody(updatedDefaultSetupKey), }, { name: "Delete Setup Key", @@ -228,7 +227,7 @@ func TestSetupKeysHandlers(t *testing.T) { func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) { t.Helper() // this comparison is done manually because when converting to JSON dates formatted differently - // assert.Equal(t, got.UpdatedAt, tc.expectedSetupKey.UpdatedAt) //doesn't work + // assert.Equal(t, got.UpdatedAt, tc.expectedResponse.UpdatedAt) //doesn't work assert.WithinDurationf(t, got.UpdatedAt, expected.UpdatedAt, 0, "") assert.WithinDurationf(t, got.Expires, expected.Expires, 0, "") assert.Equal(t, got.Name, expected.Name) diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go new file mode 100644 index 00000000000..5e2895bcc37 --- /dev/null +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -0,0 +1,226 @@ +package benchmarks + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" +) + +// Map to store peers, groups, users, and setupKeys by name +var benchCasesSetupKeys = map[string]testing_tools.BenchmarkCase{ + "Setup Keys - XS": {Peers: 10000, Groups: 10000, Users: 10000, SetupKeys: 5}, + "Setup Keys - S": {Peers: 5, Groups: 5, Users: 5, SetupKeys: 100}, + "Setup Keys - M": {Peers: 100, Groups: 20, Users: 20, SetupKeys: 1000}, + "Setup Keys - L": {Peers: 5, Groups: 5, Users: 5, SetupKeys: 5000}, + "Peers - L": {Peers: 10000, Groups: 5, Users: 5, SetupKeys: 5000}, + "Groups - L": {Peers: 5, Groups: 10000, Users: 5, SetupKeys: 5000}, + "Users - L": {Peers: 5, Groups: 5, Users: 10000, SetupKeys: 5000}, + "Setup Keys - XL": {Peers: 500, Groups: 50, Users: 100, SetupKeys: 25000}, +} + +func BenchmarkCreateSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + requestBody := api.CreateSetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName + strconv.Itoa(i), + Type: "reusable", + UsageLimit: 0, + } + + // the time marshal will be recorded as well but for our use case that is ok + body, err := json.Marshal(requestBody) + assert.NoError(b, err) + + req := testing_tools.BuildRequest(b, body, http.MethodPost, "/api/setup-keys", testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkUpdateSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + groupId := testing_tools.TestGroupId + if i%2 == 0 { + groupId = testing_tools.NewGroupId + } + requestBody := api.SetupKeyRequest{ + AutoGroups: []string{groupId}, + Revoked: false, + } + + // the time marshal will be recorded as well but for our use case that is ok + body, err := json.Marshal(requestBody) + assert.NoError(b, err) + + req := testing_tools.BuildRequest(b, body, http.MethodPut, "/api/setup-keys/"+testing_tools.TestKeyId, testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetOneSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys/"+testing_tools.TestKeyId, testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetAllSetupKeys(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 15}, + "Setup Keys - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 40}, + "Setup Keys - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Peers - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Groups - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Users - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Setup Keys - XL": {MinMsPerOpLocal: 140, MaxMsPerOpLocal: 220, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 500}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkDeleteSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodDelete, "/api/setup-keys/"+"oldkey-"+strconv.Itoa(i), testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go new file mode 100644 index 00000000000..193c0fb022d --- /dev/null +++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go @@ -0,0 +1,1146 @@ +package integration + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sort" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" +) + +func Test_SetupKeys_Create(t *testing.T) { + truePointer := true + + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestBody *api.CreateSetupKeyRequest + requestType string + requestPath string + userId string + }{ + { + name: "Create Setup Key", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 0, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 0, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key with already existing name", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.ExistingKeyName, + Type: "one-off", + UsageLimit: 0, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key as on-off with more than one usage", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "one-off", + UsageLimit: 3, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key with expiration in the past", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: -testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "one-off", + UsageLimit: 0, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Create Setup Key with AutoGroups that do exist", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 1, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key for ephemeral peers", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: []string{}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + Ephemeral: &truePointer, + UsageLimit: 1, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key with AutoGroups that do not exist", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: []string{"someGroupID"}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 0, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Create Setup Key", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 0, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 0, + UsedTimes: 0, + Valid: true, + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + req := testing_tools.BuildRequest(t, body, tc.requestType, tc.requestPath, user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + validateCreatedKey(t, tc.expectedResponse, got) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestBody *api.SetupKeyRequest + requestType string + requestPath string + requestId string + }{ + { + name: "Add existing Group to existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId, testing_tools.NewGroupId}, + Revoked: false, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId, testing_tools.NewGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Add non-existing Group to existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId, "someGroupId"}, + Revoked: false, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Add existing Group to non-existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: "someId", + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId, testing_tools.NewGroupId}, + Revoked: false, + }, + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + { + name: "Remove existing Group from existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{}, + Revoked: false, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Remove existing Group to non-existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: "someID", + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{}, + Revoked: false, + }, + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + { + name: "Revoke existing valid Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: true, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Revoke existing revoked Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: true, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Un-Revoke existing revoked Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: false, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Revoke existing expired Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.ExpiredKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: true, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "expired", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + validateCreatedKey(t, tc.expectedResponse, got) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_Get(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestType string + requestPath string + requestId string + }{ + { + name: "Get existing valid Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Get existing expired Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.ExpiredKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "expired", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + { + name: "Get existing revoked Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Get non-existing Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: "someId", + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectRespnose := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectRespnose { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + validateCreatedKey(t, tc.expectedResponse, got) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse []*api.SetupKey + requestType string + requestPath string + }{ + { + name: "Get all Setup Keys", + requestType: http.MethodGet, + requestPath: "/api/setup-keys", + expectedStatus: http.StatusOK, + expectedResponse: []*api.SetupKey{ + { + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + { + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + { + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "expired", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := []api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + sort.Slice(got, func(i, j int) bool { + return got[i].UsageLimit < got[j].UsageLimit + }) + + sort.Slice(tc.expectedResponse, func(i, j int) bool { + return tc.expectedResponse[i].UsageLimit < tc.expectedResponse[j].UsageLimit + }) + + for i, _ := range tc.expectedResponse { + validateCreatedKey(t, tc.expectedResponse[i], &got[i]) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got[i].Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse[i], setup_keys.ToResponseBody(key)) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestType string + requestPath string + requestId string + }{ + { + name: "Delete existing valid Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Delete existing expired Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.ExpiredKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "expired", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + { + name: "Delete existing revoked Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Delete non-existing Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: "someId", + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + _, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + assert.Errorf(t, err, "Expected error when trying to get deleted key") + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func validateCreatedKey(t *testing.T, expectedKey *api.SetupKey, got *api.SetupKey) { + t.Helper() + + if got.Expires.After(time.Now().Add(-1*time.Minute)) && got.Expires.Before(time.Now().Add(testing_tools.ExpiresIn*time.Second)) || + got.Expires.After(time.Date(2300, 01, 01, 0, 0, 0, 0, time.Local)) || + got.Expires.Before(time.Date(1950, 01, 01, 0, 0, 0, 0, time.Local)) { + got.Expires = time.Time{} + expectedKey.Expires = time.Time{} + } + + if got.Id == "" { + t.Fatalf("Expected key to have an ID") + } + got.Id = "" + + if got.Key == "" { + t.Fatalf("Expected key to have a key") + } + got.Key = "" + + if got.UpdatedAt.After(time.Now().Add(-1*time.Minute)) && got.UpdatedAt.Before(time.Now().Add(+1*time.Minute)) { + got.UpdatedAt = time.Time{} + expectedKey.UpdatedAt = time.Time{} + } + + expectedKey.UpdatedAt = expectedKey.UpdatedAt.In(time.UTC) + got.UpdatedAt = got.UpdatedAt.In(time.UTC) + + assert.Equal(t, expectedKey, got) +} diff --git a/management/server/http/testing/testdata/setup_keys.sql b/management/server/http/testing/testdata/setup_keys.sql new file mode 100644 index 00000000000..a315ea0f701 --- /dev/null +++ b/management/server/http/testing/testdata/setup_keys.sql @@ -0,0 +1,24 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); + + +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',3,0); +INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,'0001-01-01 00:00:00+00:00','["testGroupId"]',5,1); + diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go new file mode 100644 index 00000000000..da910c5c35a --- /dev/null +++ b/management/server/http/testing/testing_tools/tools.go @@ -0,0 +1,307 @@ +package testing_tools + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/groups" + nbhttp "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" +) + +const ( + TestAccountId = "testAccountId" + TestPeerId = "testPeerId" + TestGroupId = "testGroupId" + TestKeyId = "testKeyId" + + TestUserId = "testUserId" + TestAdminId = "testAdminId" + TestOwnerId = "testOwnerId" + TestServiceUserId = "testServiceUserId" + TestServiceAdminId = "testServiceAdminId" + BlockedUserId = "blockedUserId" + OtherUserId = "otherUserId" + InvalidToken = "invalidToken" + + NewKeyName = "newKey" + NewGroupId = "newGroupId" + ExpiresIn = 3600 + RevokedKeyId = "revokedKeyId" + ExpiredKeyId = "expiredKeyId" + + ExistingKeyName = "existingKey" +) + +type TB interface { + Cleanup(func()) + Helper() + Errorf(format string, args ...any) + Fatalf(format string, args ...any) + TempDir() string +} + +// BenchmarkCase defines a single benchmark test case +type BenchmarkCase struct { + Peers int + Groups int + Users int + SetupKeys int +} + +// PerformanceMetrics holds the performance expectations +type PerformanceMetrics struct { + MinMsPerOpLocal float64 + MaxMsPerOpLocal float64 + MinMsPerOpCICD float64 + MaxMsPerOpCICD float64 +} + +func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage) (http.Handler, server.AccountManager, chan struct{}) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) + if err != nil { + t.Fatalf("Failed to create test store: %v", err) + } + t.Cleanup(cleanup) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + if err != nil { + t.Fatalf("Failed to create metrics: %v", err) + } + + peersUpdateManager := server.NewPeersUpdateManager(nil) + updMsg := peersUpdateManager.CreateChannel(context.Background(), TestPeerId) + done := make(chan struct{}) + go func() { + if expectedPeerUpdate != nil { + peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate) + } else { + peerShouldNotReceiveUpdate(t, updMsg) + } + close(done) + }() + + geoMock := &geolocation.Mock{} + validatorMock := server.MocIntegratedValidator{} + am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + networksManagerMock := networks.NewManagerMock() + resourcesManagerMock := resources.NewManagerMock() + routersManagerMock := routers.NewManagerMock() + groupsManagerMock := groups.NewManagerMock() + apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, &jwtclaims.JwtValidatorMock{}, metrics, configs.AuthCfg{}, validatorMock) + if err != nil { + t.Fatalf("Failed to create API handler: %v", err) + } + + return apiHandler, am, done +} + +func peerShouldNotReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage) { + t.Helper() + select { + case msg := <-updateMessage: + t.Errorf("Unexpected message received: %+v", msg) + case <-time.After(500 * time.Millisecond): + return + } +} + +func peerShouldReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { + t.Helper() + + select { + case msg := <-updateMessage: + if msg == nil { + t.Errorf("Received nil update message, expected valid message") + } + assert.Equal(t, expected, msg) + case <-time.After(500 * time.Millisecond): + t.Errorf("Timed out waiting for update message") + } +} + +func BuildRequest(t TB, requestBody []byte, requestType, requestPath, user string) *http.Request { + t.Helper() + + req := httptest.NewRequest(requestType, requestPath, bytes.NewBuffer(requestBody)) + req.Header.Set("Authorization", "Bearer "+user) + + return req +} + +func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int, expectResponse bool) ([]byte, bool) { + t.Helper() + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + if !expectResponse { + return nil, false + } + + if status := recorder.Code; status != expectedStatus { + t.Fatalf("handler returned wrong status code: got %v want %v, content: %s", + status, expectedStatus, string(content)) + } + + return content, expectedStatus == http.StatusOK +} + +func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) { + b.Helper() + + ctx := context.Background() + account, err := am.GetAccount(ctx, TestAccountId) + if err != nil { + b.Fatalf("Failed to get account: %v", err) + } + + // Create peers + for i := 0; i < peers; i++ { + peerKey, _ := wgtypes.GeneratePrivateKey() + peer := &nbpeer.Peer{ + ID: fmt.Sprintf("oldpeer-%d", i), + DNSLabel: fmt.Sprintf("oldpeer-%d", i), + Key: peerKey.PublicKey().String(), + IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), + Status: &nbpeer.PeerStatus{}, + UserID: TestUserId, + } + account.Peers[peer.ID] = peer + } + + // Create users + for i := 0; i < users; i++ { + user := &types.User{ + Id: fmt.Sprintf("olduser-%d", i), + AccountID: account.Id, + Role: types.UserRoleUser, + } + account.Users[user.Id] = user + } + + for i := 0; i < setupKeys; i++ { + key := &types.SetupKey{ + Id: fmt.Sprintf("oldkey-%d", i), + AccountID: account.Id, + AutoGroups: []string{"someGroupID"}, + ExpiresAt: time.Now().Add(ExpiresIn * time.Second), + Name: NewKeyName + strconv.Itoa(i), + Type: "reusable", + UsageLimit: 0, + } + account.SetupKeys[key.Id] = key + } + + // Create groups and policies + account.Policies = make([]*types.Policy, 0, groups) + for i := 0; i < groups; i++ { + groupID := fmt.Sprintf("group-%d", i) + group := &types.Group{ + ID: groupID, + Name: fmt.Sprintf("Group %d", i), + } + for j := 0; j < peers/groups; j++ { + peerIndex := i*(peers/groups) + j + group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) + } + account.Groups[groupID] = group + + // Create a policy for this group + policy := &types.Policy{ + ID: fmt.Sprintf("policy-%d", i), + Name: fmt.Sprintf("Policy for Group %d", i), + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: fmt.Sprintf("rule-%d", i), + Name: fmt.Sprintf("Rule for Group %d", i), + Enabled: true, + Sources: []string{groupID}, + Destinations: []string{groupID}, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + }, + }, + } + account.Policies = append(account.Policies, policy) + } + + account.PostureChecks = []*posture.Checks{ + { + ID: "PostureChecksAll", + Name: "All", + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.0.1", + }, + }, + }, + } + + err = am.Store.SaveAccount(context.Background(), account) + if err != nil { + b.Fatalf("Failed to save account: %v", err) + } + +} + +func EvaluateBenchmarkResults(b *testing.B, name string, duration time.Duration, perfMetrics PerformanceMetrics, recorder *httptest.ResponseRecorder) { + b.Helper() + + if recorder.Code != http.StatusOK { + b.Fatalf("Benchmark %s failed: unexpected status code %d", name, recorder.Code) + } + + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + minExpected := perfMetrics.MinMsPerOpLocal + maxExpected := perfMetrics.MaxMsPerOpLocal + if os.Getenv("CI") == "true" { + minExpected = perfMetrics.MinMsPerOpCICD + maxExpected = perfMetrics.MaxMsPerOpCICD + } + + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", name, msPerOp, minExpected) + } + + if msPerOp > maxExpected { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", name, msPerOp, maxExpected) + } +} diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 47c4ca6aebf..62e9213f700 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" ) @@ -78,3 +79,45 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID func (am *DefaultAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) { return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) } + +type MocIntegratedValidator struct { + ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) +} + +func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { + return nil +} + +func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { + if a.ValidatePeerFunc != nil { + return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) + } + return update, false, nil +} +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { + validatedPeers := make(map[string]struct{}) + for _, peer := range peers { + validatedPeers[peer.ID] = struct{}{} + } + return validatedPeers, nil +} + +func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { + return peer +} + +func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { + return false, false, nil +} + +func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { + return nil +} + +func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { + // just a dummy +} + +func (MocIntegratedValidator) Stop(_ context.Context) { + // just a dummy +} diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index b91616fa569..79e59e76feb 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -72,15 +72,19 @@ type JSONWebKey struct { X5c []string `json:"x5c"` } -// JWTValidator struct to handle token validation and parsing -type JWTValidator struct { +type JWTValidator interface { + ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) +} + +// jwtValidatorImpl struct to handle token validation and parsing +type jwtValidatorImpl struct { options Options } var keyNotFound = errors.New("unable to find appropriate key") // NewJWTValidator constructor -func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) { +func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (JWTValidator, error) { keys, err := getPemKeys(ctx, keysLocation) if err != nil { return nil, err @@ -146,13 +150,13 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, options.UserProperty = "user" } - return &JWTValidator{ + return &jwtValidatorImpl{ options: options, }, nil } // ValidateAndParse validates the token and returns the parsed token -func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { +func (m *jwtValidatorImpl) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { // If the token is empty... if token == "" { // Check if it was required @@ -318,3 +322,28 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int { return 0 } + +type JwtValidatorMock struct{} + +func (j *JwtValidatorMock) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { + claimMaps := jwt.MapClaims{} + + switch token { + case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": + claimMaps[UserIDClaim] = token + claimMaps[AccountIDSuffix] = "testAccountId" + claimMaps[DomainIDSuffix] = "test.com" + claimMaps[DomainCategorySuffix] = "private" + case "otherUserId": + claimMaps[UserIDClaim] = "otherUserId" + claimMaps[AccountIDSuffix] = "otherAccountId" + claimMaps[DomainIDSuffix] = "other.com" + claimMaps[DomainCategorySuffix] = "private" + case "invalidToken": + return nil, errors.New("invalid token") + } + + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) + return jwtToken, nil +} + diff --git a/management/server/management_test.go b/management/server/management_test.go index 40514ae14db..cfa2c138f37 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -21,13 +21,10 @@ import ( "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -448,43 +445,6 @@ var _ = Describe("Management service", func() { }) }) -type MocIntegratedValidator struct { -} - -func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { - return nil -} - -func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { - return update, false, nil -} - -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { - validatedPeers := make(map[string]struct{}) - for p := range peers { - validatedPeers[p] = struct{}{} - } - return validatedPeers, nil -} - -func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { - return peer -} - -func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { - return false, false, nil -} - -func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { - return nil -} - -func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { - -} - -func (MocIntegratedValidator) Stop(_ context.Context) {} - func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { defer GinkgoRecover() @@ -547,7 +507,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. log.Fatalf("failed creating metrics: %v", err) } - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, server.MocIntegratedValidator{}, metrics) if err != nil { log.Fatalf("failed creating a manager: %v", err) } diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index 4a7b3db775c..51205f1e9b0 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -32,6 +32,9 @@ type managerImpl struct { routersManager routers.Manager } +type mockManager struct { +} + func NewManager(store store.Store, permissionsManager permissions.Manager, resourceManager resources.Manager, routersManager routers.Manager, accountManager s.AccountManager) Manager { return &managerImpl{ store: store, @@ -185,3 +188,27 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw return nil } + +func NewManagerMock() Manager { + return &mockManager{} +} + +func (m *mockManager) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) { + return []*types.Network{}, nil +} + +func (m *mockManager) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { + return network, nil +} + +func (m *mockManager) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) { + return &types.Network{}, nil +} + +func (m *mockManager) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { + return network, nil +} + +func (m *mockManager) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error { + return nil +} diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 0fff5bcf8e9..02b46294785 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -34,6 +34,9 @@ type managerImpl struct { accountManager s.AccountManager } +type mockManager struct { +} + func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager s.AccountManager) Manager { return &managerImpl{ store: store, @@ -381,3 +384,39 @@ func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transacti return eventsToStore, nil } + +func NewManagerMock() Manager { + return &mockManager{} +} + +func (m *mockManager) GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) { + return []*types.NetworkResource{}, nil +} + +func (m *mockManager) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) { + return []*types.NetworkResource{}, nil +} + +func (m *mockManager) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { + return map[string][]string{}, nil +} + +func (m *mockManager) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + return &types.NetworkResource{}, nil +} + +func (m *mockManager) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) { + return &types.NetworkResource{}, nil +} + +func (m *mockManager) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + return &types.NetworkResource{}, nil +} + +func (m *mockManager) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error { + return nil +} + +func (m *mockManager) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) { + return []func(){}, nil +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 9a4a1efb853..f2f1aad4564 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -75,7 +75,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil { - return err + return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) } setupKey, plainKey = types.GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) @@ -132,7 +132,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil { - return err + return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) } oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id)