From 39f2850802ca8fe4e80024a52eba5a73a9d0a501 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 6 Mar 2025 11:09:23 -1000 Subject: [PATCH 1/6] fix: ensure proxy object tasks do not get garbage collected prematurely fixes #388 --- src/dbus_fast/proxy_object.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/dbus_fast/proxy_object.py b/src/dbus_fast/proxy_object.py index e9c40b29..cf0c131f 100644 --- a/src/dbus_fast/proxy_object.py +++ b/src/dbus_fast/proxy_object.py @@ -65,6 +65,7 @@ def __init__( self.bus = bus self._signal_handlers: dict[str, list[SignalHandler]] = {} self._signal_match_rule = f"type='signal',sender={bus_name},interface={introspection.name},path={path}" + self._background_tasks: set[asyncio.Task] = set() _underscorer1 = re.compile(r"(.)([A-Z][a-z]+)") _underscorer2 = re.compile(r"([a-z0-9])([A-Z])") @@ -137,7 +138,11 @@ def _message_handler(self, msg: Message) -> None: cb_result = handler.fn(*data) if isinstance(cb_result, Coroutine): - asyncio.create_task(cb_result) + # Save a strong reference to the task so it doesn't get garbage + # collected before it finishes. + task = asyncio.create_task(cb_result) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.remove) def _add_signal(self, intr_signal: intr.Signal, interface: intr.Interface) -> None: def on_signal_fn(fn: Callable, *, unpack_variants: bool = False): From 2d7b69fd4cc40dc02d54c4bc3ee64a6253ce287b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 6 Mar 2025 13:33:01 -1000 Subject: [PATCH 2/6] chore: cleanup typing --- src/dbus_fast/proxy_object.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/dbus_fast/proxy_object.py b/src/dbus_fast/proxy_object.py index fa6df87d..48868fa4 100644 --- a/src/dbus_fast/proxy_object.py +++ b/src/dbus_fast/proxy_object.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import inspect import logging @@ -6,7 +8,7 @@ from collections.abc import Coroutine from dataclasses import dataclass from functools import lru_cache -from typing import Callable, Optional, Union +from typing import Callable from . import introspection as intr from . import message_bus @@ -22,7 +24,7 @@ class SignalHandler: """Signal handler.""" - fn: Callable + fn: Callable | Coroutine unpack_variants: bool @@ -57,7 +59,7 @@ def __init__( bus_name: str, path: str, introspection: intr.Interface, - bus: "message_bus.BaseMessageBus", + bus: message_bus.BaseMessageBus, ) -> None: self.bus_name = bus_name self.path = path @@ -77,7 +79,7 @@ def _to_snake_case(member: str) -> str: return BaseProxyInterface._underscorer2.sub(r"\1_\2", subbed).lower() @staticmethod - def _check_method_return(msg: Message, signature: Optional[str] = None): + def _check_method_return(msg: Message, signature: str | None = None): if msg.message_type == MessageType.ERROR: raise DBusError._from_message(msg) if msg.message_type != MessageType.METHOD_RETURN: @@ -246,8 +248,8 @@ def __init__( self, bus_name: str, path: str, - introspection: Union[intr.Node, str, ET.Element], - bus: "message_bus.BaseMessageBus", + introspection: intr.Node | str | ET.Element, + bus: message_bus.BaseMessageBus, ProxyInterface: type[BaseProxyInterface], ) -> None: assert_object_path_valid(path) @@ -310,7 +312,7 @@ def get_interface(self, name: str) -> BaseProxyInterface: for intr_signal in intr_interface.signals: interface._add_signal(intr_signal, interface) - def get_owner_notify(msg: Message, err: Optional[Exception]) -> None: + def get_owner_notify(msg: Message, err: Exception | None) -> None: if err: logging.error(f'getting name owner for "{name}" failed, {err}') return @@ -339,7 +341,7 @@ def get_owner_notify(msg: Message, err: Optional[Exception]) -> None: self._interfaces[name] = interface return interface - def get_children(self) -> list["BaseProxyObject"]: + def get_children(self) -> list[BaseProxyObject]: """Get the child nodes of this proxy object according to the introspection data.""" if self._children is None: self._children = [ From 779a9b96e19c1756850dde0c21d42a0a97773057 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 6 Mar 2025 13:33:36 -1000 Subject: [PATCH 3/6] chore: cleanup typing --- src/dbus_fast/proxy_object.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/dbus_fast/proxy_object.py b/src/dbus_fast/proxy_object.py index 48868fa4..a2872b15 100644 --- a/src/dbus_fast/proxy_object.py +++ b/src/dbus_fast/proxy_object.py @@ -147,7 +147,7 @@ def _message_handler(self, msg: Message) -> None: task.add_done_callback(self._background_tasks.remove) def _add_signal(self, intr_signal: intr.Signal, interface: intr.Interface) -> None: - def on_signal_fn(fn: Callable, *, unpack_variants: bool = False): + def on_signal_fn(fn: Callable | Coroutine, *, unpack_variants: bool = False): fn_signature = inspect.signature(fn) if ( len( @@ -189,7 +189,9 @@ def on_signal_fn(fn: Callable, *, unpack_variants: bool = False): SignalHandler(fn, unpack_variants) ) - def off_signal_fn(fn: Callable, *, unpack_variants: bool = False) -> None: + def off_signal_fn( + fn: Callable | Coroutine, *, unpack_variants: bool = False + ) -> None: try: i = self._signal_handlers[intr_signal.name].index( SignalHandler(fn, unpack_variants) From 666619aec3eb0d30085533c2ced0037277c3997d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 6 Mar 2025 13:37:45 -1000 Subject: [PATCH 4/6] trace --- src/dbus_fast/proxy_object.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dbus_fast/proxy_object.py b/src/dbus_fast/proxy_object.py index a2872b15..cf686f39 100644 --- a/src/dbus_fast/proxy_object.py +++ b/src/dbus_fast/proxy_object.py @@ -139,6 +139,7 @@ def _message_handler(self, msg: Message) -> None: data = body cb_result = handler.fn(*data) + raise ValueError("signal handlers cannot return values") if isinstance(cb_result, Coroutine): # Save a strong reference to the task so it doesn't get garbage # collected before it finishes. From 8bc74351695af2af423330eb9f812a84d7321347 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 6 Mar 2025 13:39:40 -1000 Subject: [PATCH 5/6] Revert "trace" This reverts commit 666619aec3eb0d30085533c2ced0037277c3997d. --- src/dbus_fast/proxy_object.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dbus_fast/proxy_object.py b/src/dbus_fast/proxy_object.py index cf686f39..a2872b15 100644 --- a/src/dbus_fast/proxy_object.py +++ b/src/dbus_fast/proxy_object.py @@ -139,7 +139,6 @@ def _message_handler(self, msg: Message) -> None: data = body cb_result = handler.fn(*data) - raise ValueError("signal handlers cannot return values") if isinstance(cb_result, Coroutine): # Save a strong reference to the task so it doesn't get garbage # collected before it finishes. From 14eead0d41b537c0682c9cbc6f3d0d0d0325c0a4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 6 Mar 2025 13:56:41 -1000 Subject: [PATCH 6/6] chore: add coro tests --- tests/client/test_signals.py | 72 ++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/tests/client/test_signals.py b/tests/client/test_signals.py index 85b3fc4c..65299a2d 100644 --- a/tests/client/test_signals.py +++ b/tests/client/test_signals.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from dbus_fast import Message @@ -359,6 +361,76 @@ def kwarg_bad_handler(value, *, bad_kwarg): bus2.disconnect() +@pytest.mark.asyncio +async def test_coro_callback(): + """Test callback for signal with a coroutine.""" + bus1 = await MessageBus().connect() + bus2 = await MessageBus().connect() + + await bus1.request_name("test.signals.name") + service_interface = ExampleInterface() + bus1.export("/test/path", service_interface) + + obj = bus2.get_proxy_object( + "test.signals.name", "/test/path", bus1._introspect_export_path("/test/path") + ) + interface = obj.get_interface(service_interface.name) + + async def ping(): + await bus2.call( + Message( + destination=bus1.unique_name, + interface="org.freedesktop.DBus.Peer", + path="/test/path", + member="Ping", + ) + ) + + kwargs_handler_counter = 0 + kwargs_handler_err = None + kwarg_default_handler_counter = 0 + kwarg_default_handler_err = None + + async def kwargs_handler(value, **_): + nonlocal kwargs_handler_counter + nonlocal kwargs_handler_err + try: + assert value == "hello" + kwargs_handler_counter += 1 + except AssertionError as ex: + kwargs_handler_err = ex + + async def kwarg_default_handler(value, *, _=True): + nonlocal kwarg_default_handler_counter + nonlocal kwarg_default_handler_err + try: + assert value == "hello" + kwarg_default_handler_counter += 1 + except AssertionError as ex: + kwarg_default_handler_err = ex + + interface.on_some_signal(kwargs_handler) + interface.on_some_signal(kwarg_default_handler) + await ping() + + service_interface.SomeSignal() + await ping() + await asyncio.sleep(0) + assert kwargs_handler_err is None + assert kwargs_handler_counter == 1 + assert kwarg_default_handler_err is None + assert kwarg_default_handler_counter == 1 + + def kwarg_bad_handler(value, *, bad_kwarg): + pass + + with pytest.raises(TypeError): + interface.on_some_signal(kwarg_bad_handler) + + bus1.disconnect() + bus2.disconnect() + + @pytest.mark.asyncio async def test_on_signal_type_error(): """Test on callback raises type errors for invalid callbacks."""