Skip to content

Commit

Permalink
Fix rule update handler and tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
gigovich committed Mar 7, 2023
1 parent 53886c3 commit 0155f58
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 36 deletions.
1 change: 1 addition & 0 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -1259,6 +1259,7 @@ func addAllGroup(account *Account) error {
}
account.Rules = map[string]*Rule{defaultRule.ID: defaultRule}

// TODO: after migration we need to drop rule and create policy directly
defaultPolicy, err := RuleToPolicy(defaultRule)
if err != nil {
return fmt.Errorf("convert rule to policy: %w", err)
Expand Down
35 changes: 20 additions & 15 deletions management/server/http/rules_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ func (h *RulesHandler) UpdateRule(w http.ResponseWriter, r *http.Request) {
return
}

_, ok := account.Rules[ruleID]
if !ok {
policy, err := h.accountManager.GetPolicy(account.Id, ruleID, user.Id)
if err != nil {
util.WriteError(status.Errorf(status.NotFound, "couldn't find rule id %s", ruleID), w)
return
}
Expand All @@ -98,28 +98,33 @@ func (h *RulesHandler) UpdateRule(w http.ResponseWriter, r *http.Request) {
reqDestinations = *req.Destinations
}

rule := server.Rule{
ID: ruleID,
Name: req.Name,
Source: reqSources,
Destination: reqDestinations,
Disabled: req.Disabled,
Description: req.Description,
if len(policy.Rules) != 1 {
util.WriteError(status.Errorf(status.Internal, "policy should contain exactly one rule"), w)
return
}

policy.Name = req.Name
policy.Description = req.Description
policy.Enabled = !req.Disabled
policy.Rules[0].ID = ruleID
policy.Rules[0].Name = req.Name
policy.Rules[0].Sources = reqSources
policy.Rules[0].Destinations = reqDestinations
policy.Rules[0].Enabled = !req.Disabled
policy.Rules[0].Description = req.Description
if err := policy.UpdateQueryFromRules(); err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule to convert it to policy"), w)
return
}

switch req.Flow {
case server.TrafficFlowBidirectString:
rule.Flow = server.TrafficFlowBidirect
policy.Rules[0].Action = server.PolicyTrafficActionAccept
default:
util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w)
return
}

policy, err := server.RuleToPolicy(&rule)
if err != nil {
util.WriteError(err, w)
return
}
err = h.accountManager.SavePolicy(account.Id, user.Id, policy)
if err != nil {
util.WriteError(err, w)
Expand Down
39 changes: 18 additions & 21 deletions management/server/http/rules_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,24 @@ import (
)

func initRulesTestData(rules ...*server.Rule) *RulesHandler {
testPolicies := make(map[string]*server.Policy, len(rules))
for _, rule := range rules {
policy, err := server.RuleToPolicy(rule)
if err != nil {
panic(err)
}
if err := policy.UpdateQueryFromRules(); err != nil {
panic(err)
}
testPolicies[policy.ID] = policy
}
return &RulesHandler{
accountManager: &mock_server.MockAccountManager{
GetPolicyFunc: func(_, policyID, _ string) (*server.Policy, error) {
if policyID != "idoftherule" {
policy, ok := testPolicies[policyID]
if !ok {
return nil, fmt.Errorf("not found")
}
policy := &server.Policy{
ID: "idoftherule",
Name: "Policy",
Enabled: true,
Description: "Description",
Rules: []*server.PolicyRule{
{
ID: "idoftherule",
Name: "Rule",
Enabled: true,
Sources: []string{"idofsrcrule"},
Destinations: []string{"idofdestrule"},
Action: server.PolicyTrafficActionAccept,
},
},
}
if err := policy.UpdateQueryFromRules(); err != nil {
return nil, err
}
return policy, nil
},
SavePolicyFunc: func(_, _ string, policy *server.Policy) error {
Expand Down Expand Up @@ -228,7 +221,11 @@ func TestRulesWriteRule(t *testing.T) {
},
}

p := initRulesTestData()
p := initRulesTestData(&server.Rule{
ID: "id-existed",
Name: "Default POSTed Rule",
Flow: server.TrafficFlowBidirect,
})

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
Expand Down

0 comments on commit 0155f58

Please sign in to comment.