Skip to content

Commit

Permalink
feat: add update_policy()
Browse files Browse the repository at this point in the history
Signed-off-by: Zxilly <[email protected]>

feat: further development on update_policy()

Signed-off-by: Zxilly <[email protected]>

refactor: adjust to Python-style variable naming

Signed-off-by: Zxilly <[email protected]>

feat: add update_policies()

Signed-off-by: Zxilly <[email protected]>

feat: add unittest for update_policies()

Signed-off-by: Zxilly <[email protected]>

refactor: remove adapter check

Signed-off-by: Zxilly <[email protected]>

refactor: remove duplicated check

Signed-off-by: Zxilly <[email protected]>
  • Loading branch information
Zxilly committed Feb 3, 2021
1 parent 5676034 commit 7f7d26f
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 1 deletion.
34 changes: 34 additions & 0 deletions casbin/internal_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,40 @@ def _add_policies(self,sec,ptype,rules):
self.watcher.update()

return rules_added

def _update_policy(self, sec, ptype, old_rule, new_rule):
"""updates a rule from the current policy."""
rule_updated = self.model.update_policy(sec, ptype, old_rule, new_rule)

if not rule_updated:
return rule_updated

if self.adapter and self.auto_save:

if self.adapter.update_policy(sec, ptype, old_rule, new_rule) is False:
return False

if self.watcher:
self.watcher.update()

return rule_updated

def _update_policies(self, sec, ptype, old_rules, new_rules):
"""updates rules from the current policy."""
rules_updated = self.model.update_policies(sec, ptype, old_rules, new_rules)

if not rules_updated:
return rules_updated

if self.adapter and self.auto_save:

if self.adapter.update_policies(sec, ptype, old_rules, new_rules) is False:
return False

if self.watcher:
self.watcher.update()

return rules_updated

def _remove_policy(self, sec, ptype, rule):
"""removes a rule from the current policy."""
Expand Down
16 changes: 16 additions & 0 deletions casbin/management_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,22 @@ def add_named_policies(self,ptype,rules):
Otherwise the function returns true for the corresponding by adding the new rule."""
return self._add_policies('p',ptype,rules)

def update_policy(self, old_rule, new_rule):
"""updates an authorization rule from the current policy."""
return self.update_named_policy('p', old_rule, new_rule)

def update_policies(self, old_rules, new_rules):
"""updates authorization rules from the current policy."""
return self.update_named_policies('p', old_rules, new_rules)

def update_named_policy(self, ptype, old_rule, new_rule):
"""updates an authorization rule from the current named policy."""
return self._update_policy('p', ptype, old_rule, new_rule)

def update_named_policies(self, ptype, old_rules, new_rules):
"""updates authorization rules from the current named policy."""
return self._update_policies('p', ptype, old_rules, new_rules)

def remove_policy(self, *params):
"""removes an authorization rule from the current policy."""
return self.remove_named_policy('p', *params)
Expand Down
20 changes: 19 additions & 1 deletion casbin/model/policy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from casbin import util
import logging

from casbin import util


class Policy:
def __init__(self):
Expand Down Expand Up @@ -88,6 +89,23 @@ def add_policies(self,sec,ptype,rules):

return True

def update_policy(self, sec, ptype, old_rule, new_rule):
"""update a policy rule from the model."""

if not self.has_policy(sec, ptype, old_rule):
return False

return self.remove_policy(sec, ptype, old_rule) and self.add_policy(sec, ptype, new_rule)

def update_policies(self, sec, ptype, old_rules, new_rules):
"""update policy rules from the model."""

for rule in old_rules:
if not self.has_policy(sec, ptype, rule):
return False

return self.remove_policies(sec, ptype, old_rules) and self.add_policies(sec, ptype, new_rules)

def remove_policy(self, sec, ptype, rule):
"""removes a policy rule from the model."""

Expand Down
38 changes: 38 additions & 0 deletions tests/model/test_policy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from unittest import TestCase

from casbin.model import Model
from tests.test_enforcer import get_examples

Expand Down Expand Up @@ -53,6 +54,43 @@ def test_add_role_policy(self):
self.assertTrue(m.get_policy('p', 'p') == [p_rule1, p_rule2])
self.assertTrue(m.get_policy('g', 'g') == [g_rule])

def test_update_policy(self):
m = Model()
m.load_model(get_examples("basic_model.conf"))

old_rule = ['admin', 'domain1', 'data1', 'read']
new_rule = ['admin', 'domain1', 'data2', 'read']

m.add_policy('p', 'p', old_rule)
self.assertTrue(m.has_policy('p', 'p', old_rule))

m.update_policy('p', 'p', old_rule, new_rule)
self.assertFalse(m.has_policy('p', 'p', old_rule))
self.assertTrue(m.has_policy('p', 'p', new_rule))

def test_update_policies(self):
m = Model()
m.load_model(get_examples("basic_model.conf"))

old_rules = [['admin', 'domain1', 'data1', 'read'],
['admin', 'domain1', 'data2', 'read'],
['admin', 'domain1', 'data3', 'read']]
new_rules = [['admin', 'domain1', 'data4', 'read'],
['admin', 'domain1', 'data5', 'read'],
['admin', 'domain1', 'data6', 'read']]

m.add_policies('p', 'p', old_rules)

for old_rule in old_rules:
self.assertTrue(m.has_policy('p', 'p', old_rule))

m.update_policies('p', 'p', old_rules, new_rules)

for old_rule in old_rules:
self.assertFalse(m.has_policy('p', 'p', old_rule))
for new_rule in new_rules:
self.assertTrue(m.has_policy('p', 'p', new_rule))

def test_remove_policy(self):
m = Model()
m.load_model(get_examples("basic_model.conf"))
Expand Down

0 comments on commit 7f7d26f

Please sign in to comment.