From eabb1bda1a5e3f4536d079835a1bf9aa4201e935 Mon Sep 17 00:00:00 2001 From: House Date: Mon, 29 Jan 2024 20:46:41 +0200 Subject: [PATCH 1/3] feat: Added support for async watcher callbacks AsyncEnforcer can now also await Watcher callbacks if they are Coroutines. --- casbin/async_internal_enforcer.py | 49 +++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/casbin/async_internal_enforcer.py b/casbin/async_internal_enforcer.py index f2bfc5f..c1d0547 100644 --- a/casbin/async_internal_enforcer.py +++ b/casbin/async_internal_enforcer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy +import inspect from casbin.core_enforcer import CoreEnforcer from casbin.model import Model, FunctionMap @@ -105,8 +106,12 @@ async def save_policy(self): await self.adapter.save_policy(self.model) if self.watcher: - if callable(getattr(self.watcher, "update_for_save_policy", None)): - self.watcher.update_for_save_policy(self.model) + update_for_save_policy = getattr(self.watcher, "update_for_save_policy", None) + if callable(update_for_save_policy): + if inspect.iscoroutinefunction(update_for_save_policy): + await update_for_save_policy(self.model) + else: + update_for_save_policy(self.model) else: self.watcher.update() @@ -122,8 +127,12 @@ async def _add_policy(self, sec, ptype, rule): return False if self.watcher and self.auto_notify_watcher: - if callable(getattr(self.watcher, "update_for_add_policy", None)): - self.watcher.update_for_add_policy(sec, ptype, rule) + update_for_add_policy = getattr(self.watcher, "update_for_add_policy", None) + if callable(update_for_add_policy): + if inspect.iscoroutinefunction(update_for_add_policy): + await update_for_add_policy(sec, ptype, rule) + else: + update_for_add_policy(sec, ptype, rule) else: self.watcher.update() @@ -144,8 +153,12 @@ async def _add_policies(self, sec, ptype, rules): return False if self.watcher and self.auto_notify_watcher: - if callable(getattr(self.watcher, "update_for_add_policies", None)): - self.watcher.update_for_add_policies(sec, ptype, rules) + update_for_add_policies = getattr(self.watcher, "update_for_add_policies", None) + if callable(update_for_add_policies): + if inspect.iscoroutinefunction(update_for_add_policies): + await update_for_add_policies(sec, ptype, rules) + else: + update_for_add_policies(sec, ptype, rules) else: self.watcher.update() @@ -224,8 +237,12 @@ async def _remove_policy(self, sec, ptype, rule): return False if self.watcher and self.auto_notify_watcher: - if callable(getattr(self.watcher, "update_for_remove_policy", None)): - self.watcher.update_for_remove_policy(sec, ptype, rule) + update_for_remove_policy = getattr(self.watcher, "update_for_remove_policy", None) + if callable(update_for_remove_policy): + if inspect.iscoroutinefunction(update_for_remove_policy): + await update_for_remove_policy(sec, ptype, rule) + else: + update_for_remove_policy(sec, ptype, rule) else: self.watcher.update() @@ -246,8 +263,12 @@ async def _remove_policies(self, sec, ptype, rules): return False if self.watcher and self.auto_notify_watcher: - if callable(getattr(self.watcher, "update_for_remove_policies", None)): - self.watcher.update_for_remove_policies(sec, ptype, rules) + update_for_remove_policies = getattr(self.watcher, "update_for_remove_policies", None) + if callable(update_for_remove_policies): + if inspect.iscoroutinefunction(update_for_remove_policies): + await update_for_remove_policies(sec, ptype, rules) + else: + update_for_remove_policies(sec, ptype, rules) else: self.watcher.update() @@ -265,8 +286,12 @@ async def _remove_filtered_policy(self, sec, ptype, field_index, *field_values): return False if self.watcher and self.auto_notify_watcher: - if callable(getattr(self.watcher, "update_for_remove_filtered_policy", None)): - self.watcher.update_for_remove_filtered_policy(sec, ptype, field_index, *field_values) + update_for_remove_filtered_policy = getattr(self.watcher, "update_for_remove_filtered_policy", None) + if callable(update_for_remove_filtered_policy): + if inspect.iscoroutinefunction(update_for_remove_filtered_policy): + await update_for_remove_filtered_policy(sec, ptype, field_index, *field_values) + else: + update_for_remove_filtered_policy(sec, ptype, field_index, *field_values) else: self.watcher.update() From 43dba498b0d5a047e03d41210068992f324b3d42 Mon Sep 17 00:00:00 2001 From: House Date: Wed, 31 Jan 2024 22:28:13 +0200 Subject: [PATCH 2/3] feat: Added tests for async watcher I've added pytests to verify the functionality of async watcher callbacks. Note: I used the existing pytest for synchronous calls as a reference for creating this async version. --- tests/test_watcher_ex.py | 180 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 179 insertions(+), 1 deletion(-) diff --git a/tests/test_watcher_ex.py b/tests/test_watcher_ex.py index e98a12f..05a8ece 100644 --- a/tests/test_watcher_ex.py +++ b/tests/test_watcher_ex.py @@ -14,7 +14,7 @@ import casbin from tests.test_enforcer import get_examples, TestCaseBase - +from unittest import IsolatedAsyncioTestCase class SampleWatcher: def __init__(self): @@ -113,6 +113,103 @@ def start_watch(self): pass +class AsyncSampleWatcher(): + def __init__(self): + self.callback = None + self.notify_message = None + + async def close(self): + pass + + async def set_update_callback(self, callback): + """ + sets the callback function to be called when the policy is updated + :param callable callback: callback(event) + - event: event received from the rabbitmq + :return: + """ + self.callback = callback + + async def update(self, msg): + """ + update the policy + """ + self.notify_message = msg + return True + + async def update_for_add_policy(self, section, ptype, *params): + """ + update for add policy + :param section: section + :param ptype: policy type + :param params: other params + :return: True if updated + """ + message = "called add policy" + return await self.update(message) + + async def update_for_remove_policy(self, section, ptype, *params): + """ + update for remove policy + :param section: section + :param ptype: policy type + :param params: other params + :return: True if updated + """ + message = "called remove policy" + return await self.update(message) + + async def update_for_remove_filtered_policy(self, section, ptype, field_index, *params): + """ + update for remove filtered policy + :param section: section + :param ptype: policy type + :param field_index: field index + :param params: other params + :return: + """ + message = "called remove filtered policy" + return await self.update(message) + + async def update_for_save_policy(self, model: casbin.Model): + """ + update for save policy + :param model: casbin model + :return: + """ + message = "called save policy" + return await self.update(message) + + async def update_for_add_policies(self, section, ptype, *params): + """ + update for add policies + :param section: section + :param ptype: policy type + :param params: other params + :return: + """ + message = "called add policies" + return await self.update(message) + + async def update_for_remove_policies(self, section, ptype, *params): + """ + update for remove policies + :param section: section + :param ptype: policy type + :param params: other params + :return: + """ + message = "called remove policies" + return await self.update(message) + + async def start_watch(self): + """ + starts the watch thread + :return: + """ + pass + + class TestWatcherEx(TestCaseBase): def get_enforcer(self, model=None, adapter=None): return casbin.Enforcer( @@ -187,3 +284,84 @@ def test_auto_notify_disabled(self): e.remove_policies(rules) self.assertEqual(w.notify_message, None) + + +class TestAsyncWatcherEx(IsolatedAsyncioTestCase): + def get_enforcer(self, model=None, adapter=None): + return casbin.AsyncEnforcer( + model, + adapter, + ) + + async def test_auto_notify_enabled(self): + e = self.get_enforcer( + get_examples("basic_model.conf"), + get_examples("basic_policy.csv"), + ) + await e.load_policy() + + w = AsyncSampleWatcher() + e.set_watcher(w) + e.enable_auto_notify_watcher(True) + + await e.save_policy() + self.assertEqual(w.notify_message, "called save policy") + + await e.add_policy("admin", "data1", "read") + self.assertEqual(w.notify_message, "called add policy") + + await e.remove_policy("admin", "data1", "read") + self.assertEqual(w.notify_message, "called remove policy") + + await e.remove_filtered_policy(1, "data1") + self.assertEqual(w.notify_message, "called remove filtered policy") + + rules = [ + ["jack", "data4", "read"], + ["katy", "data4", "write"], + ["leyo", "data4", "read"], + ["ham", "data4", "write"], + ] + await e.add_policies(rules) + self.assertEqual(w.notify_message, "called add policies") + + await e.remove_policies(rules) + self.assertEqual(w.notify_message, "called remove policies") + + + async def test_auto_notify_disabled(self): + e = self.get_enforcer( + get_examples("basic_model.conf"), + get_examples("basic_policy.csv"), + ) + await e.load_policy() + + w = SampleWatcher() + e.set_watcher(w) + e.enable_auto_notify_watcher(False) + + await e.save_policy() + self.assertEqual(w.notify_message, "called save policy") + + w.notify_message = None + + await e.add_policy("admin", "data1", "read") + self.assertEqual(w.notify_message, None) + + await e.remove_policy("admin", "data1", "read") + self.assertEqual(w.notify_message, None) + + await e.remove_filtered_policy(1, "data1") + self.assertEqual(w.notify_message, None) + + rules = [ + ["jack", "data4", "read"], + ["katy", "data4", "write"], + ["leyo", "data4", "read"], + ["ham", "data4", "write"], + ] + await e.add_policies(rules) + self.assertEqual(w.notify_message, None) + + await e.remove_policies(rules) + self.assertEqual(w.notify_message, None) \ No newline at end of file From 6657f1274ce2d7ae478644ad4459409b858df5c0 Mon Sep 17 00:00:00 2001 From: House Date: Thu, 1 Feb 2024 22:09:37 +0200 Subject: [PATCH 3/3] fix: Used black to fix linter error --- tests/test_watcher_ex.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_watcher_ex.py b/tests/test_watcher_ex.py index 05a8ece..3684087 100644 --- a/tests/test_watcher_ex.py +++ b/tests/test_watcher_ex.py @@ -16,6 +16,7 @@ from tests.test_enforcer import get_examples, TestCaseBase from unittest import IsolatedAsyncioTestCase + class SampleWatcher: def __init__(self): self.callback = None @@ -113,7 +114,7 @@ def start_watch(self): pass -class AsyncSampleWatcher(): +class AsyncSampleWatcher: def __init__(self): self.callback = None self.notify_message = None @@ -292,7 +293,7 @@ def get_enforcer(self, model=None, adapter=None): model, adapter, ) - + async def test_auto_notify_enabled(self): e = self.get_enforcer( get_examples("basic_model.conf"), @@ -328,7 +329,6 @@ async def test_auto_notify_enabled(self): await e.remove_policies(rules) self.assertEqual(w.notify_message, "called remove policies") - async def test_auto_notify_disabled(self): e = self.get_enforcer( get_examples("basic_model.conf"), @@ -364,4 +364,4 @@ async def test_auto_notify_disabled(self): self.assertEqual(w.notify_message, None) await e.remove_policies(rules) - self.assertEqual(w.notify_message, None) \ No newline at end of file + self.assertEqual(w.notify_message, None)