Skip to content

Commit

Permalink
fix merge
Browse files Browse the repository at this point in the history
Signed-off-by: bcmmbaga <[email protected]>
  • Loading branch information
bcmmbaga committed Sep 19, 2024
1 parent 9631cb4 commit 4b8680c
Show file tree
Hide file tree
Showing 13 changed files with 178 additions and 112 deletions.
7 changes: 4 additions & 3 deletions management/server/http/dns_settings_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"

"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/management/server"
Expand Down Expand Up @@ -39,12 +40,12 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
return
}

dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
if !(user.HasAdminPower() || user.IsServiceUser) {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings"), w)
return
}

dnsSettings := account.DNSSettings.Copy()
apiDNSSettings := &api.DNSSettings{
DisabledManagementGroups: dnsSettings.DisabledManagementGroups,
}
Expand Down
40 changes: 22 additions & 18 deletions management/server/http/groups_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,16 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
return
}

groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "groups are blocked for users"), w)
return
}

groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}

groupsResponse := make([]*api.Group, 0, len(groups))
for _, group := range groups {
groupsResponse = append(groupsResponse, toGroupResponse(account, group))
Expand Down Expand Up @@ -219,25 +223,25 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
return
}

switch r.Method {
case http.MethodGet:
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}

group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "groups are blocked for users"), w)
return
}

util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
group, ok := account.Groups[groupID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "group with ID %s not found", groupID), w)
return
}

util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))

}

func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
Expand Down
12 changes: 9 additions & 3 deletions management/server/http/groups_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ var TestPeers = map[string]*nbpeer.Peer{
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
}

func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
func initGroupTestData(user *server.User, groups ...*nbgroup.Group) *GroupsHandler {
return &GroupsHandler{
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
Expand All @@ -57,7 +57,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
}, nil
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
account := &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: TestPeers,
Expand All @@ -69,7 +69,13 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
"id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
},
}, user, nil
}

for _, group := range groups {
account.Groups[group.ID] = group
}

return account, user, nil
},
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
if groupID == "linked-grp" {
Expand Down
24 changes: 17 additions & 7 deletions management/server/http/nameservers_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,16 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
return
}

nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
if !(user.HasAdminPower() || user.IsServiceUser) {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups"), w)
return
}

nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
for _, item := range account.NameServerGroups {
nsGroups = append(nsGroups, item.Copy())
}

apiNameservers := make([]*api.NameserverGroup, 0)
for _, r := range nsGroups {
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
Expand Down Expand Up @@ -181,14 +185,20 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R
return
}

nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID)
if err != nil {
util.WriteError(r.Context(), err, w)
if !(user.HasAdminPower() || user.IsServiceUser) {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can view nameserver groups"), w)
return
}

resp := toNameserverGroupResponse(nsGroup)
nsGroup, ok := account.NameServerGroups[nsGroupID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound,
"nameserver group with ID %s not found", nsGroupID,
), w)
return
}

resp := toNameserverGroupResponse(nsGroup.Copy())
util.WriteJSONObject(r.Context(), w, &resp)
}

Expand Down
3 changes: 3 additions & 0 deletions management/server/http/nameservers_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ var testingNSAccount = &server.Account{
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
NameServerGroups: map[string]*nbdns.NameServerGroup{
existingNSGroupID: baseExistingNSGroup,
},
}

var baseExistingNSGroup = &nbdns.NameServerGroup{
Expand Down
34 changes: 25 additions & 9 deletions management/server/http/pat_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,25 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
}

vars := mux.Vars(r)
userID := vars["userId"]
if len(userID) == 0 {
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}

pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
targetUser, ok := account.Users[targetUserID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "user not found"), w)
return
}

if user.Id != targetUserID || !user.HasAdminPower() && !targetUser.IsServiceUser {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "no permission to get PAT for this user"), w)
return
}

var patResponse []*api.PersonalAccessToken
for _, pat := range pats {
for _, pat := range targetUser.PATs {
patResponse = append(patResponse, toPATResponse(pat))
}

Expand Down Expand Up @@ -83,9 +88,20 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
return
}

pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
targetUser, ok := account.Users[targetUserID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "user not found"), w)
return
}

if user.Id != targetUserID || !user.HasAdminPower() && !targetUser.IsServiceUser {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "no permission to get PAT for this user"), w)
return
}

pat, ok := targetUser.PATs[tokenID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "PAT not found"), w)
return
}

Expand Down
5 changes: 3 additions & 2 deletions management/server/http/pat_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ var testAccount = &server.Account{
Domain: testDomain,
Users: map[string]*server.User{
existingUserID: {
Id: existingUserID,
Id: existingUserID,
Role: server.UserRoleAdmin,
PATs: map[string]*server.PersonalAccessToken{
existingTokenID: {
ID: existingTokenID,
Expand Down Expand Up @@ -228,7 +229,7 @@ func TestTokenHandlers(t *testing.T) {
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.True(t, cmp.Equal(got, expectedTokens))
assert.ElementsMatch(t, expectedTokens, got)
case "Get Existing Token":
expectedToken := toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID])
got := &api.PersonalAccessToken{}
Expand Down
61 changes: 29 additions & 32 deletions management/server/http/policies_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package http
import (
"encoding/json"
"net/http"
"slices"
"strconv"

"github.com/gorilla/mux"
Expand Down Expand Up @@ -41,14 +42,13 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
return
}

accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
if !user.HasAdminPower() && !user.IsServiceUser {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can view policies"), w)
return
}

policies := []*api.Policy{}
for _, policy := range accountPolicies {
policies := make([]*api.Policy, 0, len(account.Policies))
for _, policy := range account.Policies {
resp := toPolicyResponse(account, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
Expand Down Expand Up @@ -76,13 +76,9 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
return
}

policyIdx := -1
for i, policy := range account.Policies {
if policy.ID == policyID {
policyIdx = i
break
}
}
policyIdx := slices.IndexFunc(account.Policies, func(policy *server.Policy) bool {
return policy.ID == policyID
})
if policyIdx < 0 {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
return
Expand Down Expand Up @@ -258,31 +254,32 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
return
}

switch r.Method {
case http.MethodGet:
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}

policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if !user.HasAdminPower() && !user.IsServiceUser {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can view policies"), w)
return
}

resp := toPolicyResponse(account, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
for _, policy := range account.Policies {
if policy.ID == policyID {
resp := toPolicyResponse(account, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
}

util.WriteJSONObject(r.Context(), w, resp)
return
}

util.WriteJSONObject(r.Context(), w, resp)
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
}

util.WriteError(r.Context(), status.Errorf(status.NotFound, "policy with ID %s not found", policyID), w)
}

func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy {
Expand Down
8 changes: 3 additions & 5 deletions management/server/http/policies_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,9 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Policies: []*server.Policy{
{ID: "id-existed"},
},
Id: claims.AccountId,
Domain: "hotmail.com",
Policies: policies,
Groups: map[string]*nbgroup.Group{
"F": {ID: "F"},
"G": {ID: "G"},
Expand Down
Loading

0 comments on commit 4b8680c

Please sign in to comment.