Skip to content

Commit

Permalink
feat: automatically unregister methods of destroyed object, add unreg…
Browse files Browse the repository at this point in the history
…ister method (#351)

* feat: automatically unregister methods of destroyed object, add unregister method

* feat: satisfy type checker
  • Loading branch information
JurgenR authored Oct 30, 2024
1 parent b582daa commit d6feedf
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 7 deletions.
46 changes: 42 additions & 4 deletions src/aioslsk/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import inspect
import logging
from typing import Any, Optional, TypeVar, TYPE_CHECKING, Union
from types import MethodType
import weakref

from .room.model import Room, RoomMessage
from .user.model import ChatMessage, User
Expand Down Expand Up @@ -81,7 +83,8 @@
# Internal functions

def on_message(message_class: type[MessageDataclass]):
"""Decorator for methods listening to specific `MessageData` events"""
"""Decorator for methods listening to specific :class:`.MessageData` events
"""
def register(event_func):
event_func._registered_message = message_class
return event_func
Expand All @@ -104,28 +107,53 @@ def build_message_map(obj: object) -> dict[type[MessageDataclass], Callable]:
class EventBus:

def __init__(self):
self._events: dict[type[Event], list[tuple[int, EventListener]]] = {}
self._events: dict[
type[Event],
list[tuple[int, weakref.ReferenceType[EventListener]]]] = {}

def register(self, event_class: type[E], listener: EventListener, priority: int = 100):
"""Registers an event listener to listen on an event class. The order in
which the listeners are called can be managed using the ``priority``
parameter
"""
entry = (priority, listener)
ref_factory: type[weakref.ReferenceType] = weakref.ref
if isinstance(listener, MethodType):
ref_factory = weakref.WeakMethod

entry = (
priority,
ref_factory(listener, lambda ref: self._remove_callback(event_class, ref))
)
try:
self._events[event_class].append(entry)
except KeyError:
self._events[event_class] = [entry, ]

self._events[event_class].sort(key=lambda e: e[0])

def unregister(self, event_class: type[E], listener: EventListener):
"""Unregisters the event listener from event bus"""
if event_class not in self._events:
return

self._events[event_class] = [
listener_def
for listener_def in self._events[event_class]
if listener != listener_def[1]()
]

async def emit(self, event: Event):
try:
listeners = self._events[event.__class__]
except KeyError:
pass
else:
for _, listener in listeners:
for _, listener_ref in listeners:
listener = listener_ref()
# Should never be None, but checked for type compatibility
if not listener: # pragma: no cover
continue

try:
if asyncio.iscoroutinefunction(listener):
await listener(event)
Expand All @@ -138,6 +166,16 @@ async def emit(self, event: Event):
listener, event
)

def _remove_callback(
self, event_class: type[E],
listener_ref: weakref.ReferenceType[EventListener]):

self._events[event_class] = [
listener_def
for listener_def in self._events[event_class]
if listener_ref != listener_def[1]
]


# Public events
class Event:
Expand Down
64 changes: 61 additions & 3 deletions tests/unit/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from unittest.mock import create_autospec
import weakref


def listener1(event: KickedEvent):
Expand All @@ -19,6 +20,9 @@ async def async_listener(event: KickedEvent):

class DummyClass:

def on_event(self, event: KickedEvent):
pass

@on_message(Login.Response)
def login(self):
pass
Expand Down Expand Up @@ -54,7 +58,9 @@ def test_whenRegisterNonExistingEvent_shouldAddListener(self):

bus.register(KickedEvent, listener1, priority=5)

assert bus._events[KickedEvent] == [(5, listener1), ]
assert bus._events[KickedEvent] == [
(5, weakref.ref(listener1)),
]

def test_whenRegisterExistingEvent_shouldAddListener(self):
bus = EventBus()
Expand All @@ -63,10 +69,62 @@ def test_whenRegisterExistingEvent_shouldAddListener(self):
bus.register(KickedEvent, listener2, priority=4)

assert bus._events[KickedEvent] == [
(4, listener2),
(5, listener1),
(4, weakref.ref(listener2)),
(5, weakref.ref(listener1)),
]

def test_whenRegisterDifferentTypes_shouldAddListener(self):
bus = EventBus()

dummy = DummyClass()
inline_func = lambda event: str(event)

bus.register(KickedEvent, listener1)
bus.register(KickedEvent, dummy.on_event)
bus.register(KickedEvent, inline_func)

assert bus._events[KickedEvent] == [
(100, weakref.ref(listener1)),
(100, weakref.WeakMethod(dummy.on_event)),
(100, weakref.ref(inline_func)),
]

def test_whenObjectIsDestroyed_shouldRemoveListener(self):
bus = EventBus()
dummy = DummyClass()

bus.register(KickedEvent, dummy.on_event)
bus.register(KickedEvent, listener1)

assert bus._events[KickedEvent] == [
(100, weakref.WeakMethod(dummy.on_event)),
(100, weakref.ref(listener1)),
]

del dummy

assert bus._events[KickedEvent] == [
(100, weakref.ref(listener1)),
]

def test_whenUnregisterExistingEvent_shouldRemoveListener(self):
bus = EventBus()

bus.register(KickedEvent, listener1)

assert bus._events[KickedEvent] == [
(100, weakref.ref(listener1)),
]

bus.unregister(KickedEvent, listener1)

assert bus._events[KickedEvent] == []

def test_whenUnregisterNonExistingEvent_shouldDoNothing(self):
bus = EventBus()

bus.unregister(KickedEvent, listener1)

@pytest.mark.asyncio
async def test_whenEmitNoListenersRegister_shouldNotRaise(self):
bus = EventBus()
Expand Down

0 comments on commit d6feedf

Please sign in to comment.