From eb4d1418300140970a7e955a26b0099ac6fcd5f5 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 27 Oct 2024 16:15:25 -0400 Subject: [PATCH] Zigpy serial protocol (#160) * Migrate zigate to zigpy serial protocol * Fix unit tests * Let zigpy handle flow control * Bump minimum zigpy version * Remove unnecessary `close` * Clean API only on close * Fix annotations * Test `connection_lost` --- pyproject.toml | 2 +- tests/test_api.py | 11 +++-- tests/test_application.py | 12 ++--- tests/test_uart.py | 15 +++---- zigpy_zigate/api.py | 6 +-- zigpy_zigate/uart.py | 70 +++++++++--------------------- zigpy_zigate/zigbee/application.py | 2 +- 7 files changed, 46 insertions(+), 72 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 452c2e4..46749e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ license = {text = "GPL-3.0"} requires-python = ">=3.8" dependencies = [ "voluptuous", - "zigpy>=0.66.0", + "zigpy>=0.70.0", "pyusb>=1.1.0", "gpiozero", 'async-timeout; python_version<"3.11"', diff --git a/tests/test_api.py b/tests/test_api.py index 0abb37c..15177ef 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,5 +1,5 @@ import asyncio -from unittest.mock import MagicMock, patch, sentinel +from unittest.mock import AsyncMock, MagicMock, patch, sentinel import pytest import serial_asyncio @@ -37,10 +37,13 @@ async def mock_conn(loop, protocol_factory, **kwargs): await api.connect() -def test_close(api): +@pytest.mark.asyncio +async def test_disconnect(api): uart = api._uart - api.close() - assert uart.close.call_count == 1 + uart.disconnect = AsyncMock() + + await api.disconnect() + assert uart.disconnect.call_count == 1 assert api._uart is None diff --git a/tests/test_application.py b/tests/test_application.py index fbff54e..b587eb4 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -102,28 +102,28 @@ async def mock_get_network_state(): @pytest.mark.asyncio async def test_disconnect_success(app): - api = MagicMock() + api = AsyncMock() app._api = api await app.disconnect() - api.close.assert_called_once() + api.disconnect.assert_called_once() assert app._api is None @pytest.mark.asyncio async def test_disconnect_failure(app, caplog): - api = MagicMock() - api.disconnect = MagicMock(side_effect=RuntimeError("Broken")) + api = AsyncMock() + api.reset = AsyncMock(side_effect=RuntimeError("Broken")) app._api = api with caplog.at_level(logging.WARNING): await app.disconnect() - assert "disconnect" in caplog.text + assert "Failed to reset before disconnect" in caplog.text - api.close.assert_called_once() + api.disconnect.assert_called_once() assert app._api is None diff --git a/tests/test_uart.py b/tests/test_uart.py index 463692b..0fe026b 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, call import gpiozero import pytest @@ -52,6 +52,12 @@ def test_close(gw): assert gw._transport.close.call_count == 1 +def test_connection_lost(gw): + exc = RuntimeError() + gw.connection_lost(exc) + assert gw._api.connection_lost.mock_calls == [call(exc)] + + def test_data_received_chunk_frame(gw): data = b"\x01\x80\x10\x02\x10\x02\x15\xaa\x02\x10\x02\x1f?\xf0\xff\x03" gw.data_received(data[:-4]) @@ -108,13 +114,6 @@ def test_escape(gw): assert r == data_escaped -def test_length(gw): - data = b"\x80\x10\x00\x05\xaa\x00\x0f?\xf0\xff" - length = 5 - r = gw._length(data) - assert r == length - - def test_checksum(gw): data = b"\x00\x0f?\xf0" checksum = 0xAA diff --git a/zigpy_zigate/api.py b/zigpy_zigate/api.py index 8e2cc17..515f8d8 100644 --- a/zigpy_zigate/api.py +++ b/zigpy_zigate/api.py @@ -246,9 +246,9 @@ def connection_lost(self, exc: Exception) -> None: if self._app is not None: self._app.connection_lost(exc) - def close(self): - if self._uart: - self._uart.close() + async def disconnect(self): + if self._uart is not None: + await self._uart.disconnect() self._uart = None def set_application(self, app): diff --git a/zigpy_zigate/uart.py b/zigpy_zigate/uart.py index 57b455a..0ecdda7 100644 --- a/zigpy_zigate/uart.py +++ b/zigpy_zigate/uart.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import asyncio import binascii import logging import struct -from typing import Any, Dict +from typing import Any import zigpy.config import zigpy.serial @@ -12,39 +14,24 @@ LOGGER = logging.getLogger(__name__) -class Gateway(asyncio.Protocol): +class Gateway(zigpy.serial.SerialProtocol): START = b"\x01" END = b"\x03" - def __init__(self, api, connected_future=None): - self._buffer = b"" - self._connected_future = connected_future + def __init__(self, api): + super().__init__() self._api = api - def connection_lost(self, exc) -> None: - """Port was closed expecteddly or unexpectedly.""" - if self._connected_future and not self._connected_future.done(): - if exc is None: - self._connected_future.set_result(True) - else: - self._connected_future.set_exception(exc) - if exc is None: - LOGGER.debug("Closed serial connection") - return - - LOGGER.error("Lost serial connection: %s", exc) - self._api.connection_lost(exc) + def connection_lost(self, exc: Exception | None) -> None: + """Port was closed expectedly or unexpectedly.""" + super().connection_lost(exc) - def connection_made(self, transport): - """Callback when the uart is connected""" - LOGGER.debug("Connection made") - self._transport = transport - if self._connected_future: - self._connected_future.set_result(True) + if self._api is not None: + self._api.connection_lost(exc) def close(self): - if self._transport: - self._transport.close() + super().close() + self._api = None def send(self, cmd, data=b""): """Send data, taking care of escaping and framing""" @@ -60,8 +47,7 @@ def send(self, cmd, data=b""): def data_received(self, data): """Callback when there is data received from the uart""" - self._buffer += data - # LOGGER.debug('data_received %s', self._buffer) + super().data_received(data) endpos = self._buffer.find(self.END) while endpos != -1: startpos = self._buffer.rfind(self.START, 0, endpos) @@ -71,7 +57,7 @@ def data_received(self, data): cmd, length, checksum, f_data, lqi = struct.unpack( "!HHB%dsB" % (len(frame) - 6), frame ) - if self._length(frame) != length: + if len(frame) - 5 != length: LOGGER.warning( "Invalid length: %s, data: %s", length, len(frame) - 6 ) @@ -126,42 +112,28 @@ def _checksum(self, *args): chcksum ^= x return chcksum - def _length(self, frame): - length = len(frame) - 5 - return length - - -async def connect(device_config: Dict[str, Any], api, loop=None): - if loop is None: - loop = asyncio.get_event_loop() - - connected_future = asyncio.Future() - protocol = Gateway(api, connected_future) +async def connect(device_config: dict[str, Any], api, loop=None): + loop = asyncio.get_running_loop() port = device_config[zigpy.config.CONF_DEVICE_PATH] - if port == "auto": - port = await loop.run_in_executor(None, c.discover_port) if await c.async_is_pizigate(port): LOGGER.debug("PiZiGate detected") await c.async_set_pizigate_running_mode() - # in case of pizigate:/dev/ttyAMA0 syntax - if port.startswith("pizigate:"): - port = port.replace("pizigate:", "", 1) + port = port.replace("pizigate:", "", 1) elif await c.async_is_zigate_din(port): LOGGER.debug("ZiGate USB DIN detected") await c.async_set_zigatedin_running_mode() - elif c.is_zigate_wifi(port): - LOGGER.debug("ZiGate WiFi detected") + protocol = Gateway(api) _, protocol = await zigpy.serial.create_serial_connection( loop, lambda: protocol, url=port, baudrate=device_config[zigpy.config.CONF_DEVICE_BAUDRATE], - xonxoff=False, + flow_control=device_config[zigpy.config.CONF_DEVICE_FLOW_CONTROL], ) - await connected_future + await protocol.wait_until_connected() return protocol diff --git a/zigpy_zigate/zigbee/application.py b/zigpy_zigate/zigbee/application.py index 9b7718c..a6a919b 100644 --- a/zigpy_zigate/zigbee/application.py +++ b/zigpy_zigate/zigbee/application.py @@ -63,7 +63,7 @@ async def disconnect(self): except Exception as e: LOGGER.warning("Failed to reset before disconnect: %s", e) finally: - self._api.close() + await self._api.disconnect() self._api = None async def start_network(self):