diff --git a/aiorussound/connection.py b/aiorussound/connection.py index ef216d8..8c85638 100644 --- a/aiorussound/connection.py +++ b/aiorussound/connection.py @@ -1,189 +1,49 @@ import asyncio import logging from abc import abstractmethod -from asyncio import AbstractEventLoop, Queue, StreamReader, StreamWriter -from typing import Any, Optional +from asyncio import StreamReader +from typing import Optional -from aiorussound import CommandError from aiorussound.const import ( DEFAULT_PORT, - RECONNECT_DELAY, - RESPONSE_REGEX, - KEEP_ALIVE_INTERVAL, ) -from aiorussound.models import RussoundMessage, MessageType _LOGGER = logging.getLogger(__package__) -# Maintain compat with various 3.x async changes -if hasattr(asyncio, "ensure_future"): - ensure_future = asyncio.ensure_future -else: - ensure_future = getattr(asyncio, "async") - - -def _process_response(res: bytes) -> Optional[RussoundMessage]: - """Process an incoming string of bytes into a RussoundMessage""" - try: - # Attempt to decode in Latin and re-encode in UTF-8 to support international characters - str_res = ( - res.decode(encoding="iso-8859-1") - .encode(encoding="utf-8") - .decode(encoding="utf-8") - .strip() - ) - except UnicodeDecodeError as e: - _LOGGER.warning("Failed to decode Russound response %s", res, e) - return None - if not str_res: - return None - if len(str_res) == 1 and str_res[0] == "S": - return RussoundMessage(MessageType.STATE, None, None, None) - tag, payload = str_res[0], str_res[2:] - if tag == "E": - _LOGGER.debug("Device responded with error: %s", payload) - return RussoundMessage(tag, None, None, payload) - m = RESPONSE_REGEX.match(payload.strip()) - if not m: - return RussoundMessage(tag, None, None, None) - return RussoundMessage(tag, m.group(1) or None, m.group(2), m.group(3)) - class RussoundConnectionHandler: - def __init__(self, loop: AbstractEventLoop) -> None: - self._loop = loop - self._connection_started: bool = False - self.connected: bool = False - self._message_callback: list[Any] = [] - self._connection_callbacks: list[Any] = [] - self._cmd_queue: Queue = Queue() - - @abstractmethod - async def close(self): - raise NotImplementedError + def __init__(self) -> None: + self.reader: Optional[StreamReader] = None async def send(self, cmd: str) -> None: """Send a command to the Russound client.""" - if not self.connected: - raise CommandError("Not connected to device.") - await self._cmd_queue.put(cmd) + pass + # if not self.connected: + # raise CommandError("Not connected to device.") @abstractmethod - async def connect(self, reconnect=True) -> None: + async def connect(self) -> None: raise NotImplementedError - async def _keep_alive(self) -> None: - while True: - await asyncio.sleep(KEEP_ALIVE_INTERVAL) # 15 minutes - _LOGGER.debug("Sending keep alive to device") - await self.send("VERSION") - - def _set_connected(self, connected: bool): - self.connected = connected - for callback in self._connection_callbacks: - callback(connected) - - def add_connection_callback(self, callback) -> None: - """Register a callback to be called whenever the instance is connected/disconnected. - The callback will be passed one argument: connected: bool. - """ - self._connection_callbacks.append(callback) - - def remove_connection_callback(self, callback) -> None: - """Removes a previously registered callback.""" - self._connection_callbacks.remove(callback) - - def add_message_callback(self, callback) -> None: - """Register a callback to be called whenever the controller sends a message. - The callback will be passed one argument: msg: str. - """ - self._message_callback.append(callback) - - def remove_message_callback(self, callback) -> None: - """Removes a previously registered callback.""" - self._message_callback.remove(callback) - - async def _on_msg_recv(self, msg: RussoundMessage) -> None: - for callback in self._message_callback: - await callback(msg) - class RussoundTcpConnectionHandler(RussoundConnectionHandler): - def __init__( - self, loop: AbstractEventLoop, host: str, port: int = DEFAULT_PORT - ) -> None: + def __init__(self, host: str, port: int = DEFAULT_PORT) -> None: """Initialize the Russound object using the event loop, host and port provided. """ - super().__init__(loop) + super().__init__() self.host = host self.port = port - self._ioloop_future = None + self.writer = None - async def connect(self, reconnect=True) -> None: - self._connection_started = True - _LOGGER.info("Connecting to %s:%s", self.host, self.port) + async def connect(self) -> None: + _LOGGER.debug("Connecting to %s:%s", self.host, self.port) reader, writer = await asyncio.open_connection(self.host, self.port) - self._ioloop_future = ensure_future(self._ioloop(reader, writer, reconnect)) - self._set_connected(True) - - async def close(self): - """Disconnect from the controller.""" - self._connection_started = False - _LOGGER.info("Closing connection to %s:%s", self.host, self.port) - self._ioloop_future.cancel() - try: - await self._ioloop_future - except asyncio.CancelledError: - pass - self._set_connected(False) + self.reader = reader + self.writer = writer - async def _ioloop( - self, reader: StreamReader, writer: StreamWriter, reconnect: bool - ) -> None: - queue_future = ensure_future(self._cmd_queue.get()) - net_future = ensure_future(reader.readline()) - keep_alive_task = asyncio.create_task(self._keep_alive()) - - try: - _LOGGER.debug("Starting IO loop") - while True: - done, _ = await asyncio.wait( - [queue_future, net_future], return_when=asyncio.FIRST_COMPLETED - ) - - if net_future in done: - response = net_future.result() - msg = _process_response(response) - if msg: - await self._on_msg_recv(msg) - net_future = ensure_future(reader.readline()) - - if queue_future in done: - cmd = queue_future.result() - writer.write(bytearray(f"{cmd}\r", "utf-8")) - await writer.drain() - queue_future = ensure_future(self._cmd_queue.get()) - except asyncio.CancelledError: - _LOGGER.debug("IO loop cancelled") - self._set_connected(False) - raise - except asyncio.TimeoutError: - _LOGGER.warning("Connection to Russound client timed out") - except ConnectionResetError: - _LOGGER.warning("Connection to Russound client reset") - except Exception: - _LOGGER.exception("Unhandled exception in IO loop") - self._set_connected(False) - raise - finally: - _LOGGER.debug("Cancelling all tasks...") - writer.close() - queue_future.cancel() - net_future.cancel() - keep_alive_task.cancel() - self._set_connected(False) - if reconnect and self._connection_started: - _LOGGER.info("Retrying connection to Russound client in 5s") - await asyncio.sleep(RECONNECT_DELAY) - await self.connect(reconnect) + async def send(self, cmd: str) -> None: + """Send a command to the Russound client.""" + await super().send(cmd) + self.writer.write(bytearray(f"{cmd}\r", "utf-8")) + await self.writer.drain() diff --git a/aiorussound/rio.py b/aiorussound/rio.py index 015033c..995ae89 100644 --- a/aiorussound/rio.py +++ b/aiorussound/rio.py @@ -4,7 +4,6 @@ import asyncio import logging -import re from asyncio import Future, Task, AbstractEventLoop, Queue from dataclasses import field, dataclass from typing import Any, Coroutine, Optional @@ -16,6 +15,9 @@ MINIMUM_API_SUPPORT, FeatureFlag, MAX_RNET_CONTROLLERS, + RESPONSE_REGEX, + KEEP_ALIVE_INTERVAL, + TIMEOUT, ) from aiorussound.exceptions import ( CommandError, @@ -27,6 +29,7 @@ CallbackType, Source, Zone, + MessageType, ) from aiorussound.util import ( controller_device_str, @@ -36,6 +39,7 @@ zone_device_str, is_rnet_capable, get_max_zones, + map_rio_to_dict, ) _LOGGER = logging.getLogger(__package__) @@ -49,22 +53,25 @@ def __init__(self, connection_handler: RussoundConnectionHandler) -> None: provided. """ self.connection_handler = connection_handler - self.connection_handler.add_message_callback(self._on_msg_recv) self._loop: AbstractEventLoop = asyncio.get_running_loop() self._subscriptions: dict[str, Any] = {} self.connect_result: Future | None = None self.connect_task: Task | None = None + self._reconnect_task: Optional[Task] = None self._state_update_callbacks: list[Any] = [] self.controllers: dict[int, Controller] = {} self.sources: dict[int, Source] = {} self.rio_version: str | None = None self.state = {} self._futures: Queue = Queue() + self._attempt_reconnection = False + self._do_state_update = False async def register_state_update_callbacks(self, callback: Any): """Register state update callback.""" self._state_update_callbacks.append(callback) - await callback(self, CallbackType.STATE) + if self._do_state_update: + await callback(self, CallbackType.STATE) def unregister_state_update_callbacks(self, callback: Any): """Unregister state update callback.""" @@ -94,116 +101,213 @@ async def request(self, cmd: str): await self._futures.put(future) try: await self.connection_handler.send(cmd) - except (CommandError, RussoundError) as ex: + except Exception as ex: _ = await self._futures.get() future.set_exception(ex) return await future - async def _on_msg_recv(self, msg: RussoundMessage) -> None: - if msg.type == "S": - future: Future = await self._futures.get() - future.set_result(msg.value) - elif msg.type == "E": - future: Future = await self._futures.get() - future.set_exception(CommandError) - if msg.branch and msg.leaf and msg.type == "N": - # Map the RIO syntax to a state dict - path = re.findall(r"\w+\[?\d*]?", msg.branch) - current = self.state - for part in path: - match = re.match(r"(\w+)\[(\d+)]", part) - if match: - key, index = match.groups() - index = int(index) - if key not in current: - current[key] = {} - if index not in current[key]: - current[key][index] = {} - current = current[key][index] - else: - if part not in current: - current[part] = {} - current = current[part] - - # Set the leaf and value in the final dictionary location - current[msg.leaf] = msg.value - subscription = self._subscriptions.get(msg.branch) - if subscription: - await subscription() - async def connect(self) -> None: """Connect to the controller and start processing responses.""" if not self.is_connected(): self.connect_result = self._loop.create_future() - self.connect_task = asyncio.create_task( - self.connect_handler(self.connect_result) + self._reconnect_task = asyncio.create_task( + self._reconnect_handler(self.connect_result) ) return await self.connect_result + async def disconnect(self) -> None: + """Disconnect from the Russound controller.""" + if self.is_connected(): + self._attempt_reconnection = False + self.connect_task.cancel() + try: + await self.connect_task + except asyncio.CancelledError: + pass + def is_connected(self) -> bool: """Return True if device is connected.""" return self.connect_task is not None and not self.connect_task.done() - async def connect_handler(self, res): - await self.connection_handler.connect(reconnect=True) - self.rio_version = await self.request("VERSION") - if not is_fw_version_higher(self.rio_version, MINIMUM_API_SUPPORT): - await self.connection_handler.close() - raise UnsupportedFeatureError( - f"Russound RIO API v{self.rio_version} is not supported. The minimum " - f"supported version is v{MINIMUM_API_SUPPORT}" - ) - _LOGGER.info("Connected (Russound RIO v%s})", self.rio_version) - - # Fetch parent controller - parent_controller = await self._load_controller(1) - if not parent_controller: - raise RussoundError("No primary controller found.") - - self.controllers[1] = parent_controller - - # Only search for daisy-chained controllers if the parent supports RNET - if is_rnet_capable(parent_controller.controller_type): - for controller_id in range(2, MAX_RNET_CONTROLLERS + 1): - controller = await self._load_controller(controller_id) - if controller: - self.controllers[controller_id] = controller - - subscribe_state_updates = {self.subscribe(self._async_handle_system, "System")} - - # Load source structure - for source_id in range(1, MAX_SOURCE): + async def _reconnect_handler(self, res): + reconnect_delay = 0.5 + while True: try: - device_str = source_device_str(source_id) - name = await self.get_variable(device_str, "name") - if name: - subscribe_state_updates.add( - self.subscribe(self._async_handle_source, device_str) - ) - except CommandError: + self.connect_task = asyncio.create_task(self._connect_handler(res)) + await self.connect_task + except Exception as ex: + _LOGGER.error(ex) + pass + await self.do_state_update_callbacks(CallbackType.CONNECTION) + if not self._attempt_reconnection: + _LOGGER.debug( + "Failed to connect to device on initial pass, skipping reconnect." + ) break + reconnect_delay = min(reconnect_delay * 2, 30) + _LOGGER.debug( + f"Attempting reconnection to Russound device in {reconnect_delay} seconds..." + ) + await asyncio.sleep(reconnect_delay) - for controller_id, controller in self.controllers.items(): - for zone_id in range(1, get_max_zones(controller.controller_type) + 1): + async def _connect_handler(self, res): + handler_tasks = set() + try: + self._do_state_update = False + await self.connection_handler.connect() + handler_tasks.add( + asyncio.create_task(self.consumer_handler(self.connection_handler)) + ) + self.rio_version = await self.request("VERSION") + if not is_fw_version_higher(self.rio_version, MINIMUM_API_SUPPORT): + raise UnsupportedFeatureError( + f"Russound RIO API v{self.rio_version} is not supported. The minimum " + f"supported version is v{MINIMUM_API_SUPPORT}" + ) + _LOGGER.info("Connected (Russound RIO v%s})", self.rio_version) + # Fetch parent controller + parent_controller = await self._load_controller(1) + if not parent_controller: + raise RussoundError("No primary controller found.") + + self.controllers[1] = parent_controller + + # Only search for daisy-chained controllers if the parent supports RNET + if is_rnet_capable(parent_controller.controller_type): + for controller_id in range(2, MAX_RNET_CONTROLLERS + 1): + controller = await self._load_controller(controller_id) + if controller: + self.controllers[controller_id] = controller + + subscribe_state_updates = { + self.subscribe(self._async_handle_system, "System") + } + + # Load source structure + for source_id in range(1, MAX_SOURCE): try: - device_str = zone_device_str(controller_id, zone_id) + device_str = source_device_str(source_id) name = await self.get_variable(device_str, "name") if name: subscribe_state_updates.add( - self.subscribe(self._async_handle_zone, device_str) + self.subscribe(self._async_handle_source, device_str) ) except CommandError: break - subscribe_tasks = set() - for state_update in subscribe_state_updates: - subscribe_tasks.add(asyncio.create_task(state_update)) - await asyncio.wait(subscribe_tasks) - - # Delay to ensure async TTL - await asyncio.sleep(0.2) - - res.set_result(True) + for controller_id, controller in self.controllers.items(): + for zone_id in range(1, get_max_zones(controller.controller_type) + 1): + try: + device_str = zone_device_str(controller_id, zone_id) + name = await self.get_variable(device_str, "name") + if name: + subscribe_state_updates.add( + self.subscribe(self._async_handle_zone, device_str) + ) + except CommandError: + break + + subscribe_tasks = set() + for state_update in subscribe_state_updates: + subscribe_tasks.add(asyncio.create_task(state_update)) + await asyncio.wait(subscribe_tasks) + + self._do_state_update = True + await self.do_state_update_callbacks(CallbackType.CONNECTION) + + # Delay to ensure async TTL + await asyncio.sleep(0.2) + self._attempt_reconnection = True + if not res.done(): + res.set_result(True) + handler_tasks.add(asyncio.create_task(self._keep_alive())) + await asyncio.wait(handler_tasks, return_when=asyncio.FIRST_COMPLETED) + except Exception as ex: + if not res.done(): + res.set_exception(ex) + _LOGGER.error(ex, exc_info=True) + finally: + for task in handler_tasks: + if not task.done(): + task.cancel() + + while not self._futures.empty(): + future = await self._futures.get() + future.cancel() + + self._do_state_update = False + + closeout = set() + closeout.update(handler_tasks) + + if closeout: + closeout_task = asyncio.create_task(asyncio.wait(closeout)) + while not closeout_task.done(): + try: + await asyncio.shield(closeout_task) + except asyncio.CancelledError: + pass + + @staticmethod + def process_response(res: bytes) -> Optional[RussoundMessage]: + """Process an incoming string of bytes into a RussoundMessage""" + try: + # Attempt to decode in Latin and re-encode in UTF-8 to support international characters + str_res = ( + res.decode(encoding="iso-8859-1") + .encode(encoding="utf-8") + .decode(encoding="utf-8") + .strip() + ) + except UnicodeDecodeError as e: + _LOGGER.warning("Failed to decode Russound response %s", res, e) + return None + if not str_res: + return None + if len(str_res) == 1 and str_res[0] == "S": + return RussoundMessage(MessageType.STATE, None, None, None) + tag, payload = str_res[0], str_res[2:] + if tag == "E": + _LOGGER.debug("Device responded with error: %s", payload) + return RussoundMessage(tag, None, None, payload) + m = RESPONSE_REGEX.match(payload.strip()) + if not m: + return RussoundMessage(tag, None, None, None) + return RussoundMessage(tag, m.group(1) or None, m.group(2), m.group(3)) + + async def consumer_handler(self, handler: RussoundConnectionHandler): + """Callback consumer handler.""" + try: + async for raw_msg in handler.reader: + msg = self.process_response(raw_msg) + if msg: + if msg.type == "S": + future: Future = await self._futures.get() + if not future.done(): + future.set_result(msg.value) + elif msg.type == "E": + future: Future = await self._futures.get() + if not future.done(): + future.set_exception(CommandError) + if msg.branch and msg.leaf and msg.type == "N": + map_rio_to_dict(self.state, msg.branch, msg.leaf, msg.value) + subscription = self._subscriptions.get(msg.branch) + if subscription: + await subscription() + except (asyncio.CancelledError, OSError): + pass + + async def _keep_alive(self) -> None: + while True: + await asyncio.sleep(KEEP_ALIVE_INTERVAL) + _LOGGER.debug("Sending keep alive to device") + try: + async with asyncio.timeout(TIMEOUT): + await self.request("VERSION") + except asyncio.TimeoutError: + _LOGGER.warning("Keep alive request to the Russound device timed out") + break + _LOGGER.debug("Ending keep alive task to attempt reconnection") async def subscribe(self, callback: Any, branch: str) -> None: self._subscriptions[branch] = callback @@ -215,7 +319,8 @@ async def subscribe(self, callback: Any, branch: str) -> None: async def _async_handle_system(self) -> None: """Handle async info update.""" - await self.do_state_update_callbacks() + if self._do_state_update: + await self.do_state_update_callbacks() async def _async_handle_source(self) -> None: """Handle async info update.""" @@ -223,7 +328,8 @@ async def _async_handle_source(self) -> None: source = Source.from_dict(source_data) source.client = self self.sources[source_id] = source - await self.do_state_update_callbacks() + if self._do_state_update: + await self.do_state_update_callbacks() async def _async_handle_zone(self) -> None: """Handle async info update.""" @@ -233,11 +339,8 @@ async def _async_handle_zone(self) -> None: zone.client = self zone.device_str = zone_device_str(controller_id, zone_id) self.controllers[controller_id].zones[zone_id] = zone - await self.do_state_update_callbacks() - - async def close(self) -> None: - """Disconnect from the controller.""" - await self.connection_handler.close() + if self._do_state_update: + await self.do_state_update_callbacks() async def set_variable( self, device_str: str, key: str, value: str @@ -298,10 +401,10 @@ async def send_event(self, event_name, *args) -> str: cmd = f"EVENT {self.device_str}!{event_name} {args}" return await self.client.request(cmd) - # def fetch_current_source(self) -> Source: - # """Return the current source as a source object.""" - # current_source = int(self.properties.current_source) - # return self.client.sources[current_source] + def fetch_current_source(self) -> Source: + """Return the current source as a source object.""" + current_source = int(self.current_source) + return self.client.sources[current_source] async def mute(self) -> str: """Mute the zone.""" diff --git a/aiorussound/util.py b/aiorussound/util.py index 779f59a..e950409 100644 --- a/aiorussound/util.py +++ b/aiorussound/util.py @@ -74,3 +74,26 @@ def get_max_zones(model: str) -> int: def is_rnet_capable(model: str) -> bool: """Return whether a controller is rnet capable.""" return model in ("MCA-88X", "MCA-88", "MCA-66", "MCA-C5", "MCA-C3") + + +def map_rio_to_dict(state: dict, branch: str, leaf: str, value: str) -> None: + """Maps a RIO variable to a python dictionary.""" + path = re.findall(r"\w+\[?\d*]?", branch) + current = state + for part in path: + match = re.match(r"(\w+)\[(\d+)]", part) + if match: + key, index = match.groups() + index = int(index) + if key not in current: + current[key] = {} + if index not in current[key]: + current[key][index] = {} + current = current[key][index] + else: + if part not in current: + current[part] = {} + current = current[part] + + # Set the leaf and value in the final dictionary location + current[leaf] = value diff --git a/examples/basic.py b/examples/basic.py deleted file mode 100644 index a53c848..0000000 --- a/examples/basic.py +++ /dev/null @@ -1,60 +0,0 @@ -"""Example for Russound RIO package""" - -import asyncio -from asyncio import AbstractEventLoop -import logging -import os - -# Add project directory to the search path so that version of the module -# is used for tests. -import sys - -from aiorussound.connection import RussoundTcpConnectionHandler - -sys.path.insert(1, os.path.join(os.path.dirname(__file__), "..")) - -from aiorussound import RussoundClient - -_LOGGER = logging.getLogger(__package__) - - -async def demo(loop: AbstractEventLoop, host: str) -> None: - conn_handler = RussoundTcpConnectionHandler(loop, host, 4999) - rus = RussoundClient(conn_handler) - await rus.connect() - _LOGGER.info("Supported Features:") - for flag in rus.supported_features: - _LOGGER.info(flag) - - _LOGGER.info("Finding sources") - await rus.init_sources() - for source_id, source in rus.sources.items(): - await source.watch() - _LOGGER.info("%s: %s", source_id, source.name) - - _LOGGER.info("Finding controllers") - controllers = await rus.enumerate_controllers() - - for c in controllers.values(): - _LOGGER.info("%s (%s): %s", c.controller_id, c.mac_address, c.controller_id) - - _LOGGER.info("Determining valid zones") - # Determine Zones - - for zone_id, zone in c.zones.items(): - await zone.watch() - _LOGGER.info("%s: %s", zone_id, zone.name) - - await asyncio.sleep(3.0) - for source_id, source in rus.sources.items(): - print(source.properties) - - while True: - await asyncio.sleep(1) - - -logging.basicConfig(level=logging.DEBUG) -loop = asyncio.get_event_loop() -loop.set_debug(True) -loop.run_until_complete(demo(loop, sys.argv[1])) -loop.close() diff --git a/examples/subscribe.py b/examples/subscribe.py index 4038278..94163b8 100644 --- a/examples/subscribe.py +++ b/examples/subscribe.py @@ -10,13 +10,12 @@ async def on_state_change(client: RussoundClient, callback_type: CallbackType): """Called when new information is received.""" - print(f"Callback Type: {callback_type}") - print(f"Sources: {client.sources}") + print(f"Callback Type: {callback_type} {client.is_connected()}") async def main(): """Subscribe demo entrypoint.""" - conn_handler = RussoundTcpConnectionHandler(asyncio.get_running_loop(), HOST, PORT) + conn_handler = RussoundTcpConnectionHandler(HOST, PORT) client = RussoundClient(conn_handler) await client.register_state_update_callbacks(on_state_change) @@ -32,7 +31,8 @@ async def main(): print(client.state) # Play media using the unit's front controls or Russound app - await asyncio.sleep(60) + await asyncio.sleep(20) + await client.disconnect() if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 63524af..f026a6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aiorussound" -version = "3.1.6" +version = "4.0.0" description = "Asyncio client for Russound RIO devices." authors = ["Noah Husby "] maintainers = ["Noah Husby "]