From 4b8680c409b5ab2d118ca81142694228434b2aeb Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 19 Sep 2024 20:07:02 +0300 Subject: [PATCH] fix merge Signed-off-by: bcmmbaga --- .../server/http/dns_settings_handler.go | 7 ++- management/server/http/groups_handler.go | 40 ++++++------ management/server/http/groups_handler_test.go | 12 +++- management/server/http/nameservers_handler.go | 24 +++++--- .../server/http/nameservers_handler_test.go | 3 + management/server/http/pat_handler.go | 34 ++++++++--- management/server/http/pat_handler_test.go | 5 +- management/server/http/policies_handler.go | 61 +++++++++---------- .../server/http/policies_handler_test.go | 8 +-- .../server/http/posture_checks_handler.go | 34 ++++++----- management/server/http/routes_handler.go | 20 +++--- management/server/http/routes_handler_test.go | 11 ++++ management/server/http/setupkeys_handler.go | 31 +++++++--- 13 files changed, 178 insertions(+), 112 deletions(-) diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/dns_settings_handler.go index 74b0e1a55ad..997077330be 100644 --- a/management/server/http/dns_settings_handler.go +++ b/management/server/http/dns_settings_handler.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" + "github.com/netbirdio/netbird/management/server/status" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" @@ -39,12 +40,12 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque return } - dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) + if !(user.HasAdminPower() || user.IsServiceUser) { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings"), w) return } + dnsSettings := account.DNSSettings.Copy() apiDNSSettings := &api.DNSSettings{ DisabledManagementGroups: dnsSettings.DisabledManagementGroups, } diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index c622d873af7..f5e1e499351 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -42,12 +42,16 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { return } - groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) + if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "groups are blocked for users"), w) return } + groups := make([]*nbgroup.Group, 0, len(account.Groups)) + for _, item := range account.Groups { + groups = append(groups, item) + } + groupsResponse := make([]*api.Group, 0, len(groups)) for _, group := range groups { groupsResponse = append(groupsResponse, toGroupResponse(account, group)) @@ -219,25 +223,25 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { return } - switch r.Method { - case http.MethodGet: - groupID := mux.Vars(r)["groupId"] - if len(groupID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) - return - } + groupID := mux.Vars(r)["groupId"] + if len(groupID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) + return + } - group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } + if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "groups are blocked for users"), w) + return + } - util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group)) - default: - util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w) + group, ok := account.Groups[groupID] + if !ok { + util.WriteError(r.Context(), status.Errorf(status.NotFound, "group with ID %s not found", groupID), w) return } + + util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group)) + } func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index d5ed07c9ef3..9396980080b 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -30,7 +30,7 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { +func initGroupTestData(user *server.User, groups ...*nbgroup.Group) *GroupsHandler { return &GroupsHandler{ accountManager: &mock_server.MockAccountManager{ SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { @@ -57,7 +57,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { }, nil }, GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ + account := &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", Peers: TestPeers, @@ -69,7 +69,13 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, }, - }, user, nil + } + + for _, group := range groups { + account.Groups[group.ID] = group + } + + return account, user, nil }, DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { if groupID == "linked-grp" { diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go index c6e00bb2d7f..271a41864ca 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/nameservers_handler.go @@ -43,12 +43,16 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re return } - nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) + if !(user.HasAdminPower() || user.IsServiceUser) { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups"), w) return } + nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups)) + for _, item := range account.NameServerGroups { + nsGroups = append(nsGroups, item.Copy()) + } + apiNameservers := make([]*api.NameserverGroup, 0) for _, r := range nsGroups { apiNameservers = append(apiNameservers, toNameserverGroupResponse(r)) @@ -181,14 +185,20 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R return } - nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID) - if err != nil { - util.WriteError(r.Context(), err, w) + if !(user.HasAdminPower() || user.IsServiceUser) { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can view nameserver groups"), w) return } - resp := toNameserverGroupResponse(nsGroup) + nsGroup, ok := account.NameServerGroups[nsGroupID] + if !ok { + util.WriteError(r.Context(), status.Errorf(status.NotFound, + "nameserver group with ID %s not found", nsGroupID, + ), w) + return + } + resp := toNameserverGroupResponse(nsGroup.Copy()) util.WriteJSONObject(r.Context(), w, &resp) } diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index 28b080571e1..f11dac5ca12 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -35,6 +35,9 @@ var testingNSAccount = &server.Account{ Users: map[string]*server.User{ "test_user": server.NewAdminUser("test_user"), }, + NameServerGroups: map[string]*nbdns.NameServerGroup{ + existingNSGroupID: baseExistingNSGroup, + }, } var baseExistingNSGroup = &nbdns.NameServerGroup{ diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index 9d8448d3dea..9a2ae376523 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -41,20 +41,25 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { } vars := mux.Vars(r) - userID := vars["userId"] - if len(userID) == 0 { + targetUserID := vars["userId"] + if len(targetUserID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID) - if err != nil { - util.WriteError(r.Context(), err, w) + targetUser, ok := account.Users[targetUserID] + if !ok { + util.WriteError(r.Context(), status.Errorf(status.NotFound, "user not found"), w) + return + } + + if user.Id != targetUserID || !user.HasAdminPower() && !targetUser.IsServiceUser { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "no permission to get PAT for this user"), w) return } var patResponse []*api.PersonalAccessToken - for _, pat := range pats { + for _, pat := range targetUser.PATs { patResponse = append(patResponse, toPATResponse(pat)) } @@ -83,9 +88,20 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) - if err != nil { - util.WriteError(r.Context(), err, w) + targetUser, ok := account.Users[targetUserID] + if !ok { + util.WriteError(r.Context(), status.Errorf(status.NotFound, "user not found"), w) + return + } + + if user.Id != targetUserID || !user.HasAdminPower() && !targetUser.IsServiceUser { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "no permission to get PAT for this user"), w) + return + } + + pat, ok := targetUser.PATs[tokenID] + if !ok { + util.WriteError(r.Context(), status.Errorf(status.NotFound, "PAT not found"), w) return } diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go index b72f71468df..2e777d567a0 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/pat_handler_test.go @@ -36,7 +36,8 @@ var testAccount = &server.Account{ Domain: testDomain, Users: map[string]*server.User{ existingUserID: { - Id: existingUserID, + Id: existingUserID, + Role: server.UserRoleAdmin, PATs: map[string]*server.PersonalAccessToken{ existingTokenID: { ID: existingTokenID, @@ -228,7 +229,7 @@ func TestTokenHandlers(t *testing.T) { if err = json.Unmarshal(content, &got); err != nil { t.Fatalf("Sent content is not in correct json format; %v", err) } - assert.True(t, cmp.Equal(got, expectedTokens)) + assert.ElementsMatch(t, expectedTokens, got) case "Get Existing Token": expectedToken := toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]) got := &api.PersonalAccessToken{} diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 9622668f49f..cbceaf17cb5 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -3,6 +3,7 @@ package http import ( "encoding/json" "net/http" + "slices" "strconv" "github.com/gorilla/mux" @@ -41,14 +42,13 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { return } - accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) + if !user.HasAdminPower() && !user.IsServiceUser { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can view policies"), w) return } - policies := []*api.Policy{} - for _, policy := range accountPolicies { + policies := make([]*api.Policy, 0, len(account.Policies)) + for _, policy := range account.Policies { resp := toPolicyResponse(account, policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) @@ -76,13 +76,9 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { return } - policyIdx := -1 - for i, policy := range account.Policies { - if policy.ID == policyID { - policyIdx = i - break - } - } + policyIdx := slices.IndexFunc(account.Policies, func(policy *server.Policy) bool { + return policy.ID == policyID + }) if policyIdx < 0 { util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w) return @@ -258,31 +254,32 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { return } - switch r.Method { - case http.MethodGet: - vars := mux.Vars(r) - policyID := vars["policyId"] - if len(policyID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) - return - } + vars := mux.Vars(r) + policyID := vars["policyId"] + if len(policyID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) + return + } - policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } + if !user.HasAdminPower() && !user.IsServiceUser { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can view policies"), w) + return + } - resp := toPolicyResponse(account, policy) - if len(resp.Rules) == 0 { - util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) + for _, policy := range account.Policies { + if policy.ID == policyID { + resp := toPolicyResponse(account, policy) + if len(resp.Rules) == 0 { + util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) + return + } + + util.WriteJSONObject(r.Context(), w, resp) return } - - util.WriteJSONObject(r.Context(), w, resp) - default: - util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w) } + + util.WriteError(r.Context(), status.Errorf(status.NotFound, "policy with ID %s not found", policyID), w) } func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy { diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 06274fb072d..60fa39f069f 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -48,11 +48,9 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { user := server.NewAdminUser("test_user") return &server.Account{ - Id: claims.AccountId, - Domain: "hotmail.com", - Policies: []*server.Policy{ - {ID: "id-existed"}, - }, + Id: claims.AccountId, + Domain: "hotmail.com", + Policies: policies, Groups: map[string]*nbgroup.Group{ "F": {ID: "F"}, "G": {ID: "G"}, diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 059cb3b8055..ed498849cf6 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -3,6 +3,7 @@ package http import ( "encoding/json" "net/http" + "slices" "github.com/gorilla/mux" @@ -43,14 +44,13 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt return } - accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) + if !user.HasAdminPower() { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can view posture checks"), w) return } - postureChecks := []*api.PostureCheck{} - for _, postureCheck := range accountPostureChecks { + postureChecks := make([]*api.PostureCheck, 0, len(account.PostureChecks)) + for _, postureCheck := range account.PostureChecks { postureChecks = append(postureChecks, postureCheck.ToAPIResponse()) } @@ -73,13 +73,9 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http return } - postureChecksIdx := -1 - for i, postureCheck := range account.PostureChecks { - if postureCheck.ID == postureChecksID { - postureChecksIdx = i - break - } - } + postureChecksIdx := slices.IndexFunc(account.PostureChecks, func(postureChecks *posture.Checks) bool { + return postureChecks.ID == postureChecksID + }) if postureChecksIdx < 0 { util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w) return @@ -116,13 +112,19 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re return } - postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) + if !user.HasAdminPower() { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "only users with admin power can view posture checks"), w) return } - util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse()) + for _, postureChecks := range account.PostureChecks { + if postureChecks.ID == postureChecksID { + util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse()) + return + } + } + + util.WriteError(r.Context(), status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID), w) } // DeletePostureCheck handles posture check deletion request diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index 18c347334ed..184c40a51a5 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -10,7 +10,6 @@ import ( "unicode/utf8" "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" @@ -49,13 +48,13 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { return } - routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) + if !user.HasAdminPower() && !user.IsServiceUser { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes"), w) return } + apiRoutes := make([]*api.Route, 0) - for _, route := range routes { + for _, route := range account.Routes { route, err := toRouteResponse(route) if err != nil { util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w) @@ -301,9 +300,14 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { return } - foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) - if err != nil { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w) + if !user.HasAdminPower() && !user.IsServiceUser { + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes"), w) + return + } + + foundRoute, ok := account.Routes[route.ID(routeID)] + if !ok { + util.WriteError(r.Context(), status.Errorf(status.NotFound, "route with ID %s not found", routeID), w) return } diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index 40075eb9d8b..d732d5ecf32 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -140,6 +140,17 @@ func initRoutesTestData() *RoutesHandler { return nil }, GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + route2 := baseExistingRoute.Copy() + route2.PeerGroups = []string{existingGroupID} + + route3 := baseExistingRoute.Copy() + route3.Domains = domain.List{existingDomain} + + testingAccount.Routes = map[route.ID]*route.Route{ + existingRouteID: baseExistingRoute.Copy(), + existingRouteID2: route2, + existingRouteID3: route3, + } return testingAccount, testingAccount.Users["test_user"], nil }, }, diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 8ee7dfabaa7..92a52634d94 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -102,13 +102,27 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { return } - key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID) - if err != nil { - util.WriteError(r.Context(), err, w) + if !user.HasAdminPower() && !user.IsServiceUser { + util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "only users with admin power can view setup keys"), w) return } - writeSuccess(r.Context(), w, key) + for _, key := range account.SetupKeys { + if key.Id == keyID { + foundKey := key.Copy() + + // the UpdatedAt field was introduced later, + // so there might be that some keys have a Zero value (e.g, null in the store file) + if foundKey.UpdatedAt.IsZero() { + foundKey.UpdatedAt = foundKey.CreatedAt + } + + writeSuccess(r.Context(), w, key) + return + } + } + + util.WriteError(r.Context(), status.Errorf(status.NotFound, "setup key not found"), w) } // UpdateSetupKey is a PUT request to update server.SetupKey @@ -167,15 +181,14 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques return } - setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) + if !user.HasAdminPower() && !user.IsServiceUser { + util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "only users with admin power can view setup keys"), w) return } apiSetupKeys := make([]*api.SetupKey, 0) - for _, key := range setupKeys { - apiSetupKeys = append(apiSetupKeys, toResponseBody(key)) + for _, key := range account.SetupKeys { + apiSetupKeys = append(apiSetupKeys, toResponseBody(key.Copy())) } util.WriteJSONObject(r.Context(), w, apiSetupKeys)