Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added support for async watcher callbacks #340 #341

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions casbin/async_internal_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect

from casbin.core_enforcer import CoreEnforcer
from casbin.model import Model, FunctionMap
Expand Down Expand Up @@ -105,8 +106,12 @@ async def save_policy(self):
await self.adapter.save_policy(self.model)

if self.watcher:
if callable(getattr(self.watcher, "update_for_save_policy", None)):
self.watcher.update_for_save_policy(self.model)
update_for_save_policy = getattr(self.watcher, "update_for_save_policy", None)
if callable(update_for_save_policy):
if inspect.iscoroutinefunction(update_for_save_policy):
await update_for_save_policy(self.model)
else:
update_for_save_policy(self.model)
else:
self.watcher.update()

Expand All @@ -122,8 +127,12 @@ async def _add_policy(self, sec, ptype, rule):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_add_policy", None)):
self.watcher.update_for_add_policy(sec, ptype, rule)
update_for_add_policy = getattr(self.watcher, "update_for_add_policy", None)
if callable(update_for_add_policy):
if inspect.iscoroutinefunction(update_for_add_policy):
await update_for_add_policy(sec, ptype, rule)
else:
update_for_add_policy(sec, ptype, rule)
else:
self.watcher.update()

Expand All @@ -144,8 +153,12 @@ async def _add_policies(self, sec, ptype, rules):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_add_policies", None)):
self.watcher.update_for_add_policies(sec, ptype, rules)
update_for_add_policies = getattr(self.watcher, "update_for_add_policies", None)
if callable(update_for_add_policies):
if inspect.iscoroutinefunction(update_for_add_policies):
await update_for_add_policies(sec, ptype, rules)
else:
update_for_add_policies(sec, ptype, rules)
else:
self.watcher.update()

Expand Down Expand Up @@ -224,8 +237,12 @@ async def _remove_policy(self, sec, ptype, rule):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_remove_policy", None)):
self.watcher.update_for_remove_policy(sec, ptype, rule)
update_for_remove_policy = getattr(self.watcher, "update_for_remove_policy", None)
if callable(update_for_remove_policy):
if inspect.iscoroutinefunction(update_for_remove_policy):
await update_for_remove_policy(sec, ptype, rule)
else:
update_for_remove_policy(sec, ptype, rule)
else:
self.watcher.update()

Expand All @@ -246,8 +263,12 @@ async def _remove_policies(self, sec, ptype, rules):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_remove_policies", None)):
self.watcher.update_for_remove_policies(sec, ptype, rules)
update_for_remove_policies = getattr(self.watcher, "update_for_remove_policies", None)
if callable(update_for_remove_policies):
if inspect.iscoroutinefunction(update_for_remove_policies):
await update_for_remove_policies(sec, ptype, rules)
else:
update_for_remove_policies(sec, ptype, rules)
else:
self.watcher.update()

Expand All @@ -265,8 +286,12 @@ async def _remove_filtered_policy(self, sec, ptype, field_index, *field_values):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_remove_filtered_policy", None)):
self.watcher.update_for_remove_filtered_policy(sec, ptype, field_index, *field_values)
update_for_remove_filtered_policy = getattr(self.watcher, "update_for_remove_filtered_policy", None)
if callable(update_for_remove_filtered_policy):
if inspect.iscoroutinefunction(update_for_remove_filtered_policy):
await update_for_remove_filtered_policy(sec, ptype, field_index, *field_values)
else:
update_for_remove_filtered_policy(sec, ptype, field_index, *field_values)
else:
self.watcher.update()

Expand Down
178 changes: 178 additions & 0 deletions tests/test_watcher_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import casbin
from tests.test_enforcer import get_examples, TestCaseBase
from unittest import IsolatedAsyncioTestCase


class SampleWatcher:
Expand Down Expand Up @@ -113,6 +114,103 @@ def start_watch(self):
pass


class AsyncSampleWatcher:
def __init__(self):
self.callback = None
self.notify_message = None

async def close(self):
pass

async def set_update_callback(self, callback):
"""
sets the callback function to be called when the policy is updated
:param callable callback: callback(event)
- event: event received from the rabbitmq
:return:
"""
self.callback = callback

async def update(self, msg):
"""
update the policy
"""
self.notify_message = msg
return True

async def update_for_add_policy(self, section, ptype, *params):
"""
update for add policy
:param section: section
:param ptype: policy type
:param params: other params
:return: True if updated
"""
message = "called add policy"
return await self.update(message)

async def update_for_remove_policy(self, section, ptype, *params):
"""
update for remove policy
:param section: section
:param ptype: policy type
:param params: other params
:return: True if updated
"""
message = "called remove policy"
return await self.update(message)

async def update_for_remove_filtered_policy(self, section, ptype, field_index, *params):
"""
update for remove filtered policy
:param section: section
:param ptype: policy type
:param field_index: field index
:param params: other params
:return:
"""
message = "called remove filtered policy"
return await self.update(message)

async def update_for_save_policy(self, model: casbin.Model):
"""
update for save policy
:param model: casbin model
:return:
"""
message = "called save policy"
return await self.update(message)

async def update_for_add_policies(self, section, ptype, *params):
"""
update for add policies
:param section: section
:param ptype: policy type
:param params: other params
:return:
"""
message = "called add policies"
return await self.update(message)

async def update_for_remove_policies(self, section, ptype, *params):
"""
update for remove policies
:param section: section
:param ptype: policy type
:param params: other params
:return:
"""
message = "called remove policies"
return await self.update(message)

async def start_watch(self):
"""
starts the watch thread
:return:
"""
pass


class TestWatcherEx(TestCaseBase):
def get_enforcer(self, model=None, adapter=None):
return casbin.Enforcer(
Expand Down Expand Up @@ -187,3 +285,83 @@ def test_auto_notify_disabled(self):

e.remove_policies(rules)
self.assertEqual(w.notify_message, None)


class TestAsyncWatcherEx(IsolatedAsyncioTestCase):
def get_enforcer(self, model=None, adapter=None):
return casbin.AsyncEnforcer(
model,
adapter,
)

async def test_auto_notify_enabled(self):
e = self.get_enforcer(
get_examples("basic_model.conf"),
get_examples("basic_policy.csv"),
)
await e.load_policy()

w = AsyncSampleWatcher()
e.set_watcher(w)
e.enable_auto_notify_watcher(True)

await e.save_policy()
self.assertEqual(w.notify_message, "called save policy")

await e.add_policy("admin", "data1", "read")
self.assertEqual(w.notify_message, "called add policy")

await e.remove_policy("admin", "data1", "read")
self.assertEqual(w.notify_message, "called remove policy")

await e.remove_filtered_policy(1, "data1")
self.assertEqual(w.notify_message, "called remove filtered policy")

rules = [
["jack", "data4", "read"],
["katy", "data4", "write"],
["leyo", "data4", "read"],
["ham", "data4", "write"],
]
await e.add_policies(rules)
self.assertEqual(w.notify_message, "called add policies")

await e.remove_policies(rules)
self.assertEqual(w.notify_message, "called remove policies")

async def test_auto_notify_disabled(self):
e = self.get_enforcer(
get_examples("basic_model.conf"),
get_examples("basic_policy.csv"),
)
await e.load_policy()

w = SampleWatcher()
e.set_watcher(w)
e.enable_auto_notify_watcher(False)

await e.save_policy()
self.assertEqual(w.notify_message, "called save policy")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whether save_policy should also depend on auto_notify_watcher ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hsluoyz I saw the same in go's casbin.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any updates on this issue?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@leeqvip it should behave the same way as Go

w.notify_message = None

await e.add_policy("admin", "data1", "read")
self.assertEqual(w.notify_message, None)

await e.remove_policy("admin", "data1", "read")
self.assertEqual(w.notify_message, None)

await e.remove_filtered_policy(1, "data1")
self.assertEqual(w.notify_message, None)

rules = [
["jack", "data4", "read"],
["katy", "data4", "write"],
["leyo", "data4", "read"],
["ham", "data4", "write"],
]
await e.add_policies(rules)
self.assertEqual(w.notify_message, None)

await e.remove_policies(rules)
self.assertEqual(w.notify_message, None)