diff --git a/src/dbus_fast/proxy_object.py b/src/dbus_fast/proxy_object.py index 5795b33f..a2872b15 100644 --- a/src/dbus_fast/proxy_object.py +++ b/src/dbus_fast/proxy_object.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import inspect import logging @@ -6,7 +8,7 @@ from collections.abc import Coroutine from dataclasses import dataclass from functools import lru_cache -from typing import Callable, Optional, Union +from typing import Callable from . import introspection as intr from . import message_bus @@ -22,7 +24,7 @@ class SignalHandler: """Signal handler.""" - fn: Callable + fn: Callable | Coroutine unpack_variants: bool @@ -57,7 +59,7 @@ def __init__( bus_name: str, path: str, introspection: intr.Interface, - bus: "message_bus.BaseMessageBus", + bus: message_bus.BaseMessageBus, ) -> None: self.bus_name = bus_name self.path = path @@ -65,6 +67,7 @@ def __init__( self.bus = bus self._signal_handlers: dict[str, list[SignalHandler]] = {} self._signal_match_rule = f"type='signal',sender={bus_name},interface={introspection.name},path={path}" + self._background_tasks: set[asyncio.Task] = set() _underscorer1 = re.compile(r"(.)([A-Z][a-z]+)") _underscorer2 = re.compile(r"([a-z0-9])([A-Z])") @@ -76,7 +79,7 @@ def _to_snake_case(member: str) -> str: return BaseProxyInterface._underscorer2.sub(r"\1_\2", subbed).lower() @staticmethod - def _check_method_return(msg: Message, signature: Optional[str] = None): + def _check_method_return(msg: Message, signature: str | None = None): if msg.message_type == MessageType.ERROR: raise DBusError._from_message(msg) if msg.message_type != MessageType.METHOD_RETURN: @@ -137,10 +140,14 @@ def _message_handler(self, msg: Message) -> None: cb_result = handler.fn(*data) if isinstance(cb_result, Coroutine): - asyncio.create_task(cb_result) # noqa: RUF006 + # Save a strong reference to the task so it doesn't get garbage + # collected before it finishes. + task = asyncio.create_task(cb_result) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.remove) def _add_signal(self, intr_signal: intr.Signal, interface: intr.Interface) -> None: - def on_signal_fn(fn: Callable, *, unpack_variants: bool = False): + def on_signal_fn(fn: Callable | Coroutine, *, unpack_variants: bool = False): fn_signature = inspect.signature(fn) if ( len( @@ -182,7 +189,9 @@ def on_signal_fn(fn: Callable, *, unpack_variants: bool = False): SignalHandler(fn, unpack_variants) ) - def off_signal_fn(fn: Callable, *, unpack_variants: bool = False) -> None: + def off_signal_fn( + fn: Callable | Coroutine, *, unpack_variants: bool = False + ) -> None: try: i = self._signal_handlers[intr_signal.name].index( SignalHandler(fn, unpack_variants) @@ -241,8 +250,8 @@ def __init__( self, bus_name: str, path: str, - introspection: Union[intr.Node, str, ET.Element], - bus: "message_bus.BaseMessageBus", + introspection: intr.Node | str | ET.Element, + bus: message_bus.BaseMessageBus, ProxyInterface: type[BaseProxyInterface], ) -> None: assert_object_path_valid(path) @@ -305,7 +314,7 @@ def get_interface(self, name: str) -> BaseProxyInterface: for intr_signal in intr_interface.signals: interface._add_signal(intr_signal, interface) - def get_owner_notify(msg: Message, err: Optional[Exception]) -> None: + def get_owner_notify(msg: Message, err: Exception | None) -> None: if err: logging.error(f'getting name owner for "{name}" failed, {err}') return @@ -334,7 +343,7 @@ def get_owner_notify(msg: Message, err: Optional[Exception]) -> None: self._interfaces[name] = interface return interface - def get_children(self) -> list["BaseProxyObject"]: + def get_children(self) -> list[BaseProxyObject]: """Get the child nodes of this proxy object according to the introspection data.""" if self._children is None: self._children = [ diff --git a/tests/client/test_signals.py b/tests/client/test_signals.py index 85b3fc4c..65299a2d 100644 --- a/tests/client/test_signals.py +++ b/tests/client/test_signals.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from dbus_fast import Message @@ -359,6 +361,76 @@ def kwarg_bad_handler(value, *, bad_kwarg): bus2.disconnect() +@pytest.mark.asyncio +async def test_coro_callback(): + """Test callback for signal with a coroutine.""" + bus1 = await MessageBus().connect() + bus2 = await MessageBus().connect() + + await bus1.request_name("test.signals.name") + service_interface = ExampleInterface() + bus1.export("/test/path", service_interface) + + obj = bus2.get_proxy_object( + "test.signals.name", "/test/path", bus1._introspect_export_path("/test/path") + ) + interface = obj.get_interface(service_interface.name) + + async def ping(): + await bus2.call( + Message( + destination=bus1.unique_name, + interface="org.freedesktop.DBus.Peer", + path="/test/path", + member="Ping", + ) + ) + + kwargs_handler_counter = 0 + kwargs_handler_err = None + kwarg_default_handler_counter = 0 + kwarg_default_handler_err = None + + async def kwargs_handler(value, **_): + nonlocal kwargs_handler_counter + nonlocal kwargs_handler_err + try: + assert value == "hello" + kwargs_handler_counter += 1 + except AssertionError as ex: + kwargs_handler_err = ex + + async def kwarg_default_handler(value, *, _=True): + nonlocal kwarg_default_handler_counter + nonlocal kwarg_default_handler_err + try: + assert value == "hello" + kwarg_default_handler_counter += 1 + except AssertionError as ex: + kwarg_default_handler_err = ex + + interface.on_some_signal(kwargs_handler) + interface.on_some_signal(kwarg_default_handler) + await ping() + + service_interface.SomeSignal() + await ping() + await asyncio.sleep(0) + assert kwargs_handler_err is None + assert kwargs_handler_counter == 1 + assert kwarg_default_handler_err is None + assert kwarg_default_handler_counter == 1 + + def kwarg_bad_handler(value, *, bad_kwarg): + pass + + with pytest.raises(TypeError): + interface.on_some_signal(kwarg_bad_handler) + + bus1.disconnect() + bus2.disconnect() + + @pytest.mark.asyncio async def test_on_signal_type_error(): """Test on callback raises type errors for invalid callbacks."""