diff --git a/casbin/async_internal_enforcer.py b/casbin/async_internal_enforcer.py index f2bfc5f..c1d0547 100644 --- a/casbin/async_internal_enforcer.py +++ b/casbin/async_internal_enforcer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy +import inspect from casbin.core_enforcer import CoreEnforcer from casbin.model import Model, FunctionMap @@ -105,8 +106,12 @@ async def save_policy(self): await self.adapter.save_policy(self.model) if self.watcher: - if callable(getattr(self.watcher, "update_for_save_policy", None)): - self.watcher.update_for_save_policy(self.model) + update_for_save_policy = getattr(self.watcher, "update_for_save_policy", None) + if callable(update_for_save_policy): + if inspect.iscoroutinefunction(update_for_save_policy): + await update_for_save_policy(self.model) + else: + update_for_save_policy(self.model) else: self.watcher.update() @@ -122,8 +127,12 @@ async def _add_policy(self, sec, ptype, rule): return False if self.watcher and self.auto_notify_watcher: - if callable(getattr(self.watcher, "update_for_add_policy", None)): - self.watcher.update_for_add_policy(sec, ptype, rule) + update_for_add_policy = getattr(self.watcher, "update_for_add_policy", None) + if callable(update_for_add_policy): + if inspect.iscoroutinefunction(update_for_add_policy): + await update_for_add_policy(sec, ptype, rule) + else: + update_for_add_policy(sec, ptype, rule) else: self.watcher.update() @@ -144,8 +153,12 @@ async def _add_policies(self, sec, ptype, rules): return False if self.watcher and self.auto_notify_watcher: - if callable(getattr(self.watcher, "update_for_add_policies", None)): - self.watcher.update_for_add_policies(sec, ptype, rules) + update_for_add_policies = getattr(self.watcher, "update_for_add_policies", None) + if callable(update_for_add_policies): + if inspect.iscoroutinefunction(update_for_add_policies): + await update_for_add_policies(sec, ptype, rules) + else: + update_for_add_policies(sec, ptype, rules) else: self.watcher.update() @@ -224,8 +237,12 @@ async def _remove_policy(self, sec, ptype, rule): return False if self.watcher and self.auto_notify_watcher: - if callable(getattr(self.watcher, "update_for_remove_policy", None)): - self.watcher.update_for_remove_policy(sec, ptype, rule) + update_for_remove_policy = getattr(self.watcher, "update_for_remove_policy", None) + if callable(update_for_remove_policy): + if inspect.iscoroutinefunction(update_for_remove_policy): + await update_for_remove_policy(sec, ptype, rule) + else: + update_for_remove_policy(sec, ptype, rule) else: self.watcher.update() @@ -246,8 +263,12 @@ async def _remove_policies(self, sec, ptype, rules): return False if self.watcher and self.auto_notify_watcher: - if callable(getattr(self.watcher, "update_for_remove_policies", None)): - self.watcher.update_for_remove_policies(sec, ptype, rules) + update_for_remove_policies = getattr(self.watcher, "update_for_remove_policies", None) + if callable(update_for_remove_policies): + if inspect.iscoroutinefunction(update_for_remove_policies): + await update_for_remove_policies(sec, ptype, rules) + else: + update_for_remove_policies(sec, ptype, rules) else: self.watcher.update() @@ -265,8 +286,12 @@ async def _remove_filtered_policy(self, sec, ptype, field_index, *field_values): return False if self.watcher and self.auto_notify_watcher: - if callable(getattr(self.watcher, "update_for_remove_filtered_policy", None)): - self.watcher.update_for_remove_filtered_policy(sec, ptype, field_index, *field_values) + update_for_remove_filtered_policy = getattr(self.watcher, "update_for_remove_filtered_policy", None) + if callable(update_for_remove_filtered_policy): + if inspect.iscoroutinefunction(update_for_remove_filtered_policy): + await update_for_remove_filtered_policy(sec, ptype, field_index, *field_values) + else: + update_for_remove_filtered_policy(sec, ptype, field_index, *field_values) else: self.watcher.update() diff --git a/tests/test_watcher_ex.py b/tests/test_watcher_ex.py index e98a12f..3684087 100644 --- a/tests/test_watcher_ex.py +++ b/tests/test_watcher_ex.py @@ -14,6 +14,7 @@ import casbin from tests.test_enforcer import get_examples, TestCaseBase +from unittest import IsolatedAsyncioTestCase class SampleWatcher: @@ -113,6 +114,103 @@ def start_watch(self): pass +class AsyncSampleWatcher: + def __init__(self): + self.callback = None + self.notify_message = None + + async def close(self): + pass + + async def set_update_callback(self, callback): + """ + sets the callback function to be called when the policy is updated + :param callable callback: callback(event) + - event: event received from the rabbitmq + :return: + """ + self.callback = callback + + async def update(self, msg): + """ + update the policy + """ + self.notify_message = msg + return True + + async def update_for_add_policy(self, section, ptype, *params): + """ + update for add policy + :param section: section + :param ptype: policy type + :param params: other params + :return: True if updated + """ + message = "called add policy" + return await self.update(message) + + async def update_for_remove_policy(self, section, ptype, *params): + """ + update for remove policy + :param section: section + :param ptype: policy type + :param params: other params + :return: True if updated + """ + message = "called remove policy" + return await self.update(message) + + async def update_for_remove_filtered_policy(self, section, ptype, field_index, *params): + """ + update for remove filtered policy + :param section: section + :param ptype: policy type + :param field_index: field index + :param params: other params + :return: + """ + message = "called remove filtered policy" + return await self.update(message) + + async def update_for_save_policy(self, model: casbin.Model): + """ + update for save policy + :param model: casbin model + :return: + """ + message = "called save policy" + return await self.update(message) + + async def update_for_add_policies(self, section, ptype, *params): + """ + update for add policies + :param section: section + :param ptype: policy type + :param params: other params + :return: + """ + message = "called add policies" + return await self.update(message) + + async def update_for_remove_policies(self, section, ptype, *params): + """ + update for remove policies + :param section: section + :param ptype: policy type + :param params: other params + :return: + """ + message = "called remove policies" + return await self.update(message) + + async def start_watch(self): + """ + starts the watch thread + :return: + """ + pass + + class TestWatcherEx(TestCaseBase): def get_enforcer(self, model=None, adapter=None): return casbin.Enforcer( @@ -187,3 +285,83 @@ def test_auto_notify_disabled(self): e.remove_policies(rules) self.assertEqual(w.notify_message, None) + + +class TestAsyncWatcherEx(IsolatedAsyncioTestCase): + def get_enforcer(self, model=None, adapter=None): + return casbin.AsyncEnforcer( + model, + adapter, + ) + + async def test_auto_notify_enabled(self): + e = self.get_enforcer( + get_examples("basic_model.conf"), + get_examples("basic_policy.csv"), + ) + await e.load_policy() + + w = AsyncSampleWatcher() + e.set_watcher(w) + e.enable_auto_notify_watcher(True) + + await e.save_policy() + self.assertEqual(w.notify_message, "called save policy") + + await e.add_policy("admin", "data1", "read") + self.assertEqual(w.notify_message, "called add policy") + + await e.remove_policy("admin", "data1", "read") + self.assertEqual(w.notify_message, "called remove policy") + + await e.remove_filtered_policy(1, "data1") + self.assertEqual(w.notify_message, "called remove filtered policy") + + rules = [ + ["jack", "data4", "read"], + ["katy", "data4", "write"], + ["leyo", "data4", "read"], + ["ham", "data4", "write"], + ] + await e.add_policies(rules) + self.assertEqual(w.notify_message, "called add policies") + + await e.remove_policies(rules) + self.assertEqual(w.notify_message, "called remove policies") + + async def test_auto_notify_disabled(self): + e = self.get_enforcer( + get_examples("basic_model.conf"), + get_examples("basic_policy.csv"), + ) + await e.load_policy() + + w = SampleWatcher() + e.set_watcher(w) + e.enable_auto_notify_watcher(False) + + await e.save_policy() + self.assertEqual(w.notify_message, "called save policy") + + w.notify_message = None + + await e.add_policy("admin", "data1", "read") + self.assertEqual(w.notify_message, None) + + await e.remove_policy("admin", "data1", "read") + self.assertEqual(w.notify_message, None) + + await e.remove_filtered_policy(1, "data1") + self.assertEqual(w.notify_message, None) + + rules = [ + ["jack", "data4", "read"], + ["katy", "data4", "write"], + ["leyo", "data4", "read"], + ["ham", "data4", "write"], + ] + await e.add_policies(rules) + self.assertEqual(w.notify_message, None) + + await e.remove_policies(rules) + self.assertEqual(w.notify_message, None)