From eb32a889c8b63807cc91572a9b7988d502781391 Mon Sep 17 00:00:00 2001 From: noahhusby <32528627+noahhusby@users.noreply.github.com> Date: Tue, 24 Sep 2024 11:31:59 -0400 Subject: [PATCH] refact: revamp callback --- aiorussound/__init__.py | 3 +- aiorussound/connection.py | 24 ++++--- aiorussound/models.py | 52 +++++++++++--- aiorussound/rio.py | 143 +++++++++++++++++--------------------- examples/basic.py | 3 +- examples/subscribe.py | 31 +++++++++ 6 files changed, 153 insertions(+), 103 deletions(-) create mode 100644 examples/subscribe.py diff --git a/aiorussound/__init__.py b/aiorussound/__init__.py index 7046921..09ca821 100644 --- a/aiorussound/__init__.py +++ b/aiorussound/__init__.py @@ -1,4 +1,5 @@ """Asynchronous Python client for Russound RIO.""" + from .exceptions import ( CommandError, UncachedVariableError, @@ -21,5 +22,5 @@ "RussoundTcpConnectionHandler", "ZoneProperties", "SourceProperties", - "RussoundMessage" + "RussoundMessage", ] diff --git a/aiorussound/connection.py b/aiorussound/connection.py index 22b1e69..f1dd64e 100644 --- a/aiorussound/connection.py +++ b/aiorussound/connection.py @@ -21,7 +21,12 @@ def _process_response(res: bytes) -> Optional[RussoundMessage]: """Process an incoming string of bytes into a RussoundMessage""" try: # Attempt to decode in Latin and re-encode in UTF-8 to support international characters - str_res = res.decode(encoding="iso-8859-1").encode(encoding="utf-8").decode(encoding="utf-8").strip() + str_res = ( + res.decode(encoding="iso-8859-1") + .encode(encoding="utf-8") + .decode(encoding="utf-8") + .strip() + ) except UnicodeDecodeError as e: _LOGGER.warning("Failed to decode Russound response %s", res, e) return None @@ -36,7 +41,9 @@ def _process_response(res: bytes) -> Optional[RussoundMessage]: return RussoundMessage(tag, None, None, None, None, None) p = m.groupdict() value = p["value"] or p["value_only"] - return RussoundMessage(tag, p["variable"], value, p["zone"], p["controller"], p["source"]) + return RussoundMessage( + tag, p["variable"], value, p["zone"], p["controller"], p["source"] + ) class RussoundConnectionHandler: @@ -94,14 +101,15 @@ def remove_message_callback(self, callback) -> None: """Removes a previously registered callback.""" self._message_callback.remove(callback) - def _on_msg_recv(self, msg: RussoundMessage) -> None: + async def _on_msg_recv(self, msg: RussoundMessage) -> None: for callback in self._message_callback: - callback(msg) + await callback(msg) class RussoundTcpConnectionHandler(RussoundConnectionHandler): - - def __init__(self, loop: AbstractEventLoop, host: str, port: int = DEFAULT_PORT) -> None: + def __init__( + self, loop: AbstractEventLoop, host: str, port: int = DEFAULT_PORT + ) -> None: """Initialize the Russound object using the event loop, host and port provided. """ @@ -129,7 +137,7 @@ async def close(self): self._set_connected(False) async def _ioloop( - self, reader: StreamReader, writer: StreamWriter, reconnect: bool + self, reader: StreamReader, writer: StreamWriter, reconnect: bool ) -> None: queue_future = ensure_future(self._cmd_queue.get()) net_future = ensure_future(reader.readline()) @@ -148,7 +156,7 @@ async def _ioloop( try: msg = _process_response(response) if msg: - self._on_msg_recv(msg) + await self._on_msg_recv(msg) if msg.tag == "S" and last_command_future: last_command_future.set_result(msg.value) last_command_future = None diff --git a/aiorussound/models.py b/aiorussound/models.py index fe13e8a..1b58d89 100644 --- a/aiorussound/models.py +++ b/aiorussound/models.py @@ -1,5 +1,7 @@ """Models for aiorussound.""" + from dataclasses import dataclass, field +from enum import StrEnum from typing import Optional from mashumaro import field_options @@ -9,6 +11,7 @@ @dataclass class RussoundMessage: """Incoming russound message.""" + tag: str variable: Optional[str] = None value: Optional[str] = None @@ -26,18 +29,32 @@ class ZoneProperties(DataClassORJSONMixin): treble: str = field(metadata=field_options(alias="treble"), default="0") balance: str = field(metadata=field_options(alias="balance"), default="0") loudness: str = field(metadata=field_options(alias="loudness"), default="OFF") - turn_on_volume: str = field(metadata=field_options(alias="turnonvolume"), default="20") - do_not_disturb: str = field(metadata=field_options(alias="donotdisturb"), default="OFF") + turn_on_volume: str = field( + metadata=field_options(alias="turnonvolume"), default="20" + ) + do_not_disturb: str = field( + metadata=field_options(alias="donotdisturb"), default="OFF" + ) party_mode: str = field(metadata=field_options(alias="partymode"), default="OFF") status: str = field(metadata=field_options(alias="status"), default="OFF") is_mute: str = field(metadata=field_options(alias="mute"), default="OFF") - shared_source: str = field(metadata=field_options(alias="sharedsource"), default="OFF") - last_error: Optional[str] = field(metadata=field_options(alias="lasterror"), default=None) + shared_source: str = field( + metadata=field_options(alias="sharedsource"), default="OFF" + ) + last_error: Optional[str] = field( + metadata=field_options(alias="lasterror"), default=None + ) page: Optional[str] = field(metadata=field_options(alias="page"), default=None) - sleep_time_default: Optional[str] = field(metadata=field_options(alias="sleeptimedefault"), default=None) - sleep_time_remaining: Optional[str] = field(metadata=field_options(alias="sleeptimeremaining"), default=None) + sleep_time_default: Optional[str] = field( + metadata=field_options(alias="sleeptimedefault"), default=None + ) + sleep_time_remaining: Optional[str] = field( + metadata=field_options(alias="sleeptimeremaining"), default=None + ) enabled: str = field(metadata=field_options(alias="enabled"), default="False") - current_source: str = field(metadata=field_options(alias="currentsource"), default="1") + current_source: str = field( + metadata=field_options(alias="currentsource"), default="1" + ) @dataclass @@ -46,14 +63,20 @@ class SourceProperties(DataClassORJSONMixin): type: str = field(metadata=field_options(alias="type"), default=None) channel: str = field(metadata=field_options(alias="channel"), default=None) - cover_art_url: str = field(metadata=field_options(alias="covertarturl"), default=None) + cover_art_url: str = field( + metadata=field_options(alias="covertarturl"), default=None + ) channel_name: str = field(metadata=field_options(alias="channelname"), default=None) genre: str = field(metadata=field_options(alias="genre"), default=None) artist_name: str = field(metadata=field_options(alias="artistname"), default=None) album_name: str = field(metadata=field_options(alias="albumname"), default=None) - playlist_name: str = field(metadata=field_options(alias="playlistname"), default=None) + playlist_name: str = field( + metadata=field_options(alias="playlistname"), default=None + ) song_name: str = field(metadata=field_options(alias="songname"), default=None) - program_service_name: str = field(metadata=field_options(alias="programservicename"), default=None) + program_service_name: str = field( + metadata=field_options(alias="programservicename"), default=None + ) radio_text: str = field(metadata=field_options(alias="radiotext"), default=None) shuffle_mode: str = field(metadata=field_options(alias="shufflemode"), default=None) repeat_mode: str = field(metadata=field_options(alias="repeatmode"), default=None) @@ -63,4 +86,11 @@ class SourceProperties(DataClassORJSONMixin): bit_rate: str = field(metadata=field_options(alias="bitrate"), default=None) bit_depth: str = field(metadata=field_options(alias="bitdepth"), default=None) play_time: str = field(metadata=field_options(alias="playtime"), default=None) - track_time: str = field(metadata=field_options(alias="tracktime"), default=None) \ No newline at end of file + track_time: str = field(metadata=field_options(alias="tracktime"), default=None) + + +class CallbackType(StrEnum): + """Callback type.""" + + STATE = "state" + CONNECTION = "connection" diff --git a/aiorussound/rio.py b/aiorussound/rio.py index 9b59691..5200c8a 100644 --- a/aiorussound/rio.py +++ b/aiorussound/rio.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import logging from typing import Any, Coroutine @@ -19,7 +20,12 @@ UncachedVariableError, UnsupportedFeatureError, ) -from aiorussound.models import RussoundMessage, ZoneProperties, SourceProperties +from aiorussound.models import ( + RussoundMessage, + ZoneProperties, + SourceProperties, + CallbackType, +) from aiorussound.util import ( controller_device_str, get_max_zones, @@ -35,21 +41,46 @@ class RussoundClient: """Manages the RIO connection to a Russound device.""" - def __init__( - self, connection_handler: RussoundConnectionHandler - ) -> None: + def __init__(self, connection_handler: RussoundConnectionHandler) -> None: """Initialize the Russound object using the event loop, host and port provided. """ self.connection_handler = connection_handler self.connection_handler.add_message_callback(self._on_msg_recv) self._state: dict[str, dict[str, str]] = {} - self._callbacks: dict[str, list[Any]] = {} + self._state_update_callbacks: list[Any] = [] self._watched_devices: dict[str, bool] = {} self._controllers: dict[int, Controller] = {} self.sources: dict[int, Source] = {} self.rio_version: str | None = None + async def register_state_update_callbacks(self, callback: Any): + """Register state update callback.""" + self._state_update_callbacks.append(callback) + await callback(self, CallbackType.STATE) + + def unregister_state_update_callbacks(self, callback: Any): + """Unregister state update callback.""" + if callback in self._state_update_callbacks: + self._state_update_callbacks.remove(callback) + + def clear_state_update_callbacks(self): + """Clear state update callbacks.""" + self._state_update_callbacks.clear() + + async def do_state_update_callbacks( + self, callback_type: CallbackType = CallbackType.STATE + ): + """Call state update callbacks.""" + if not self._state_update_callbacks: + return + callbacks = set() + for callback in self._state_update_callbacks: + callbacks.add(callback(self, callback_type)) + + if callbacks: + await asyncio.gather(*callbacks) + def _retrieve_cached_variable(self, device_str: str, key: str) -> str: """Retrieve the cache state of the named variable for a particular device. If the variable has not been cached then the UncachedVariable @@ -62,7 +93,9 @@ def _retrieve_cached_variable(self, device_str: str, key: str) -> str: except KeyError: raise UncachedVariableError - def _store_cached_variable(self, device_str: str, key: str, value: str) -> None: + async def _store_cached_variable( + self, device_str: str, key: str, value: str + ) -> None: """Store the current known value of a device variable into the cache. Calls any device callbacks. """ @@ -71,43 +104,21 @@ def _store_cached_variable(self, device_str: str, key: str, value: str) -> None: zone_state[key] = value _LOGGER.debug("Cache store %s.%s = %s", device_str, key, value) # Handle callbacks - for callback in self._callbacks.get(device_str, []): - callback(device_str, key, value) - # Handle source callback - if device_str[0] == "S": - for controller in self._controllers.values(): - for zone in controller.zones.values(): - source = zone.fetch_current_source() - if source and source.device_str() == device_str: - for callback in self._callbacks.get(zone.device_str(), []): - callback(device_str, key, value) - - def _on_msg_recv(self, msg: RussoundMessage) -> None: + await self.do_state_update_callbacks() + + async def _on_msg_recv(self, msg: RussoundMessage) -> None: if msg.source: source_id = int(msg.source) - self._store_cached_variable( + await self._store_cached_variable( source_device_str(source_id), msg.variable, msg.value ) elif msg.zone: controller_id = int(msg.controller) zone_id = int(msg.zone) - self._store_cached_variable( + await self._store_cached_variable( zone_device_str(controller_id, zone_id), msg.variable, msg.value ) - def add_callback(self, device_str: str, callback) -> None: - """Register a callback to be called whenever a device variable changes. - The callback will be passed three arguments: the device_str, the variable - name and the variable value. - """ - callbacks = self._callbacks.setdefault(device_str, []) - callbacks.append(callback) - - def remove_callback(self, callback) -> None: - """Remove a previously registered callback.""" - for callbacks in self._callbacks.values(): - callbacks.remove(callback) - async def connect(self, reconnect=True) -> None: """Connect to the controller and start processing responses.""" await self.connection_handler.connect(reconnect=reconnect) @@ -126,7 +137,7 @@ async def close(self) -> None: await self.connection_handler.close() async def set_variable( - self, device_str: str, key: str, value: str + self, device_str: str, key: str, value: str ) -> Coroutine[Any, Any, str]: """Set a zone variable to a new value.""" return self.connection_handler.send(f'SET {device_str}.{key}="{value}"') @@ -173,7 +184,7 @@ async def enumerate_controllers(self) -> dict[int, Controller]: pass firmware_version = None if is_feature_supported( - self.rio_version, FeatureFlag.PROPERTY_FIRMWARE_VERSION + self.rio_version, FeatureFlag.PROPERTY_FIRMWARE_VERSION ): firmware_version = await self.get_variable( device_str, "firmwareVersion" @@ -237,13 +248,13 @@ class Controller: """Uniquely identifies a controller.""" def __init__( - self, - instance: RussoundClient, - parent_controller: Controller, - controller_id: int, - mac_address: str, - controller_type: str, - firmware_version: str, + self, + instance: RussoundClient, + parent_controller: Controller, + controller_id: int, + mac_address: str, + controller_type: str, + firmware_version: str, ) -> None: """Initialize the controller.""" self.instance = instance @@ -266,8 +277,8 @@ def __str__(self) -> str: def __eq__(self, other: object) -> bool: """Equality check.""" return ( - hasattr(other, "controller_id") - and other.controller_id == self.controller_id + hasattr(other, "controller_id") + and other.controller_id == self.controller_id ) def __hash__(self) -> int: @@ -289,14 +300,6 @@ async def _init_zones(self) -> None: except CommandError: break - def add_callback(self, callback) -> None: - """Add a callback function to be called when a zone is changed.""" - self.instance.add_callback(controller_device_str(self.controller_id), callback) - - def remove_callback(self, callback) -> None: - """Remove a callback function to be called when a zone is changed.""" - self.instance.remove_callback(callback) - class Zone: """Uniquely identifies a zone @@ -307,7 +310,7 @@ class Zone: """ def __init__( - self, instance: RussoundClient, controller: Controller, zone_id: int, name: str + self, instance: RussoundClient, controller: Controller, zone_id: int, name: str ) -> None: """Initialize a zone object.""" self.instance = instance @@ -330,10 +333,10 @@ def __str__(self) -> str: def __eq__(self, other: object) -> bool: """Equality check.""" return ( - hasattr(other, "zone_id") - and hasattr(other, "controller") - and other.zone_id == self.zone_id - and other.controller == self.controller + hasattr(other, "zone_id") + and hasattr(other, "controller") + and other.zone_id == self.zone_id + and other.controller == self.controller ) def __hash__(self) -> int: @@ -358,14 +361,6 @@ async def unwatch(self) -> str: """Remove a zone from the watchlist.""" return await self.instance.unwatch(self.device_str()) - def add_callback(self, callback) -> None: - """Adds a callback function to be called when a zone is changed.""" - self.instance.add_callback(self.device_str(), callback) - - def remove_callback(self, callback) -> None: - """Remove a zone from the watchlist.""" - self.instance.remove_callback(callback) - async def send_event(self, event_name, *args) -> str: """Send an event to a zone.""" cmd = f"EVENT {self.device_str()}!{event_name} {" ".join(str(x) for x in args)}" @@ -439,9 +434,7 @@ async def select_source(self, source: int) -> str: class Source: """Uniquely identifies a Source.""" - def __init__( - self, instance: RussoundClient, source_id: int, name: str - ) -> None: + def __init__(self, instance: RussoundClient, source_id: int, name: str) -> None: """Initialize a Source.""" self.instance = instance self.source_id = int(source_id) @@ -461,10 +454,7 @@ def __str__(self) -> str: def __eq__(self, other: object) -> bool: """Equality check.""" - return ( - hasattr(other, "source_id") - and other.source_id == self.source_id - ) + return hasattr(other, "source_id") and other.source_id == self.source_id def __hash__(self) -> int: """Hash the current configuration of the source.""" @@ -476,14 +466,6 @@ def device_str(self) -> str: """ return source_device_str(self.source_id) - def add_callback(self, callback: Any) -> None: - """Add a callback function to the zone.""" - self.instance.add_callback(self.device_str(), callback) - - def remove_callback(self, callback: Any) -> None: - """Remove a callback from the source.""" - self.instance.remove_callback(callback) - async def watch(self) -> str: """Add a source to the watchlist. Sources on the watchlist will push all @@ -509,4 +491,3 @@ def _get(self, variable: str) -> str: @property def properties(self) -> SourceProperties: return SourceProperties.from_dict(self.instance.get_cache(self.device_str())) - diff --git a/examples/basic.py b/examples/basic.py index fef3a68..a53c848 100644 --- a/examples/basic.py +++ b/examples/basic.py @@ -13,7 +13,7 @@ sys.path.insert(1, os.path.join(os.path.dirname(__file__), "..")) -from aiorussound import Zone, RussoundClient +from aiorussound import RussoundClient _LOGGER = logging.getLogger(__package__) @@ -49,7 +49,6 @@ async def demo(loop: AbstractEventLoop, host: str) -> None: for source_id, source in rus.sources.items(): print(source.properties) - while True: await asyncio.sleep(1) diff --git a/examples/subscribe.py b/examples/subscribe.py new file mode 100644 index 0000000..5c4605a --- /dev/null +++ b/examples/subscribe.py @@ -0,0 +1,31 @@ +import asyncio + +from aiorussound import RussoundTcpConnectionHandler, RussoundClient +from aiorussound.models import CallbackType + +HOST = "192.168.20.17" +PORT = 4999 + + +async def on_state_change(client: RussoundClient, callback_type: CallbackType): + """Called when new information is received.""" + print(f"Callback Type: {callback_type}") + print(f"Sources: {client.sources}") + + +async def main(): + """Subscribe demo entrypoint.""" + conn_handler = RussoundTcpConnectionHandler(asyncio.get_running_loop(), HOST, PORT) + client = RussoundClient(conn_handler) + + await client.register_state_update_callbacks(on_state_change) + await client.connect() + await client.enumerate_controllers() + await client.init_sources() + + # Play media using the unit's front controls or Russound app + await asyncio.sleep(60) + + +if __name__ == "__main__": + asyncio.run(main())