Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement new reconnection handler #55

Merged
merged 4 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 20 additions & 160 deletions aiorussound/connection.py
Original file line number Diff line number Diff line change
@@ -1,189 +1,49 @@
import asyncio
import logging
from abc import abstractmethod
from asyncio import AbstractEventLoop, Queue, StreamReader, StreamWriter
from typing import Any, Optional
from asyncio import StreamReader
from typing import Optional

from aiorussound import CommandError
from aiorussound.const import (
DEFAULT_PORT,
RECONNECT_DELAY,
RESPONSE_REGEX,
KEEP_ALIVE_INTERVAL,
)
from aiorussound.models import RussoundMessage, MessageType

_LOGGER = logging.getLogger(__package__)

# Maintain compat with various 3.x async changes
if hasattr(asyncio, "ensure_future"):
ensure_future = asyncio.ensure_future
else:
ensure_future = getattr(asyncio, "async")


def _process_response(res: bytes) -> Optional[RussoundMessage]:
"""Process an incoming string of bytes into a RussoundMessage"""
try:
# Attempt to decode in Latin and re-encode in UTF-8 to support international characters
str_res = (
res.decode(encoding="iso-8859-1")
.encode(encoding="utf-8")
.decode(encoding="utf-8")
.strip()
)
except UnicodeDecodeError as e:
_LOGGER.warning("Failed to decode Russound response %s", res, e)
return None
if not str_res:
return None
if len(str_res) == 1 and str_res[0] == "S":
return RussoundMessage(MessageType.STATE, None, None, None)
tag, payload = str_res[0], str_res[2:]
if tag == "E":
_LOGGER.debug("Device responded with error: %s", payload)
return RussoundMessage(tag, None, None, payload)
m = RESPONSE_REGEX.match(payload.strip())
if not m:
return RussoundMessage(tag, None, None, None)
return RussoundMessage(tag, m.group(1) or None, m.group(2), m.group(3))


class RussoundConnectionHandler:
def __init__(self, loop: AbstractEventLoop) -> None:
self._loop = loop
self._connection_started: bool = False
self.connected: bool = False
self._message_callback: list[Any] = []
self._connection_callbacks: list[Any] = []
self._cmd_queue: Queue = Queue()

@abstractmethod
async def close(self):
raise NotImplementedError
def __init__(self) -> None:
self.reader: Optional[StreamReader] = None

async def send(self, cmd: str) -> None:
"""Send a command to the Russound client."""
if not self.connected:
raise CommandError("Not connected to device.")
await self._cmd_queue.put(cmd)
pass
# if not self.connected:
# raise CommandError("Not connected to device.")

@abstractmethod
async def connect(self, reconnect=True) -> None:
async def connect(self) -> None:
raise NotImplementedError

async def _keep_alive(self) -> None:
while True:
await asyncio.sleep(KEEP_ALIVE_INTERVAL) # 15 minutes
_LOGGER.debug("Sending keep alive to device")
await self.send("VERSION")

def _set_connected(self, connected: bool):
self.connected = connected
for callback in self._connection_callbacks:
callback(connected)

def add_connection_callback(self, callback) -> None:
"""Register a callback to be called whenever the instance is connected/disconnected.
The callback will be passed one argument: connected: bool.
"""
self._connection_callbacks.append(callback)

def remove_connection_callback(self, callback) -> None:
"""Removes a previously registered callback."""
self._connection_callbacks.remove(callback)

def add_message_callback(self, callback) -> None:
"""Register a callback to be called whenever the controller sends a message.
The callback will be passed one argument: msg: str.
"""
self._message_callback.append(callback)

def remove_message_callback(self, callback) -> None:
"""Removes a previously registered callback."""
self._message_callback.remove(callback)

async def _on_msg_recv(self, msg: RussoundMessage) -> None:
for callback in self._message_callback:
await callback(msg)


class RussoundTcpConnectionHandler(RussoundConnectionHandler):
def __init__(
self, loop: AbstractEventLoop, host: str, port: int = DEFAULT_PORT
) -> None:
def __init__(self, host: str, port: int = DEFAULT_PORT) -> None:
"""Initialize the Russound object using the event loop, host and port
provided.
"""
super().__init__(loop)
super().__init__()
self.host = host
self.port = port
self._ioloop_future = None
self.writer = None

async def connect(self, reconnect=True) -> None:
self._connection_started = True
_LOGGER.info("Connecting to %s:%s", self.host, self.port)
async def connect(self) -> None:
_LOGGER.debug("Connecting to %s:%s", self.host, self.port)
reader, writer = await asyncio.open_connection(self.host, self.port)
self._ioloop_future = ensure_future(self._ioloop(reader, writer, reconnect))
self._set_connected(True)

async def close(self):
"""Disconnect from the controller."""
self._connection_started = False
_LOGGER.info("Closing connection to %s:%s", self.host, self.port)
self._ioloop_future.cancel()
try:
await self._ioloop_future
except asyncio.CancelledError:
pass
self._set_connected(False)
self.reader = reader
self.writer = writer

async def _ioloop(
self, reader: StreamReader, writer: StreamWriter, reconnect: bool
) -> None:
queue_future = ensure_future(self._cmd_queue.get())
net_future = ensure_future(reader.readline())
keep_alive_task = asyncio.create_task(self._keep_alive())

try:
_LOGGER.debug("Starting IO loop")
while True:
done, _ = await asyncio.wait(
[queue_future, net_future], return_when=asyncio.FIRST_COMPLETED
)

if net_future in done:
response = net_future.result()
msg = _process_response(response)
if msg:
await self._on_msg_recv(msg)
net_future = ensure_future(reader.readline())

if queue_future in done:
cmd = queue_future.result()
writer.write(bytearray(f"{cmd}\r", "utf-8"))
await writer.drain()
queue_future = ensure_future(self._cmd_queue.get())
except asyncio.CancelledError:
_LOGGER.debug("IO loop cancelled")
self._set_connected(False)
raise
except asyncio.TimeoutError:
_LOGGER.warning("Connection to Russound client timed out")
except ConnectionResetError:
_LOGGER.warning("Connection to Russound client reset")
except Exception:
_LOGGER.exception("Unhandled exception in IO loop")
self._set_connected(False)
raise
finally:
_LOGGER.debug("Cancelling all tasks...")
writer.close()
queue_future.cancel()
net_future.cancel()
keep_alive_task.cancel()
self._set_connected(False)
if reconnect and self._connection_started:
_LOGGER.info("Retrying connection to Russound client in 5s")
await asyncio.sleep(RECONNECT_DELAY)
await self.connect(reconnect)
async def send(self, cmd: str) -> None:
"""Send a command to the Russound client."""
await super().send(cmd)
self.writer.write(bytearray(f"{cmd}\r", "utf-8"))
await self.writer.drain()
Loading