Skip to content

Commit

Permalink
Use zigpy flow control
Browse files Browse the repository at this point in the history
  • Loading branch information
puddly committed Oct 25, 2024
1 parent 06e3054 commit fa6e64d
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 57 deletions.
9 changes: 0 additions & 9 deletions tests/test_uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,3 @@ async def test_connection_lost(dummy_serial_conn, mocker, event_loop):

# Losing a connection propagates up to the ZNP object
assert (await conn_lost_fut) == exception


async def test_connection_made(dummy_serial_conn, mocker):
device, _ = dummy_serial_conn
znp = mocker.Mock()

await znp_uart.connect(conf.SCHEMA_DEVICE({conf.CONF_DEVICE_PATH: device}), api=znp)

znp.connection_made.assert_called_once_with()
15 changes: 9 additions & 6 deletions zigpy_znp/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,16 +748,11 @@ async def connect(self, *, test_port=True) -> None:
LOGGER.debug("Detected Z-Stack %s", self.version)
except (Exception, asyncio.CancelledError):
LOGGER.debug("Connection to %s failed, cleaning up", self._port_path)
self.close()
await self.disconnect()
raise

LOGGER.debug("Connected to %s", self._uart.url)

def connection_made(self) -> None:
"""
Called by the UART object when a connection has been made.
"""

def connection_lost(self, exc) -> None:
"""
Called by the UART object to indicate that the port was closed. Propagates up
Expand Down Expand Up @@ -786,8 +781,16 @@ def close(self) -> None:
self.version = None
self.capabilities = None

async def disconnect(self) -> None:
"""
Disconnects from the ZNP device.
"""

self.close()

if self._uart is not None:
self._uart.close()
await self._uart.wait_until_closed()
self._uart = None

def remove_listener(self, listener: BaseResponseListener) -> None:
Expand Down
50 changes: 9 additions & 41 deletions zigpy_znp/uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,50 +20,27 @@ class BufferTooShort(Exception):
pass


class ZnpMtProtocol(asyncio.Protocol):
class ZnpMtProtocol(zigpy.serial.SerialProtocol):
def __init__(self, api, *, url: str | None = None) -> None:
self._buffer = bytearray()
super().__init__()
self._api = api
self._transport = None
self._connected_event = asyncio.Event()

self.url = url

def close(self) -> None:
"""Closes the port."""

self._api = None
self._buffer.clear()

if self._transport is not None:
LOGGER.debug("Closing serial port")

self._transport.close()
self._transport = None
super().close()

def connection_lost(self, exc: Exception | None) -> None:
"""Connection lost."""

if exc is not None:
LOGGER.warning("Lost connection", exc_info=exc)
super().connection_lost(exc)

if self._api is not None:
self._api.connection_lost(exc)

def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Opened serial port."""
self._transport = transport
LOGGER.debug("Opened %s serial port", self.url)

self._connected_event.set()

if self._api is not None:
self._api.connection_made()

def data_received(self, data: bytes) -> None:
"""Callback when data is received."""
self._buffer += data

super().data_received(data)
LOGGER.log(log.TRACE, "Received data: %s", Bytes.__repr__(data))

for frame in self._extract_frames():
Expand Down Expand Up @@ -160,25 +137,16 @@ def __repr__(self) -> str:


async def connect(config: conf.ConfigType, api) -> ZnpMtProtocol:
loop = asyncio.get_running_loop()

port = config[zigpy.config.CONF_DEVICE_PATH]
baudrate = config[zigpy.config.CONF_DEVICE_BAUDRATE]
flow_control = config[zigpy.config.CONF_DEVICE_FLOW_CONTROL]

LOGGER.debug("Connecting to %s at %s baud", port, baudrate)

_, protocol = await zigpy.serial.create_serial_connection(
loop=loop,
loop=asyncio.get_running_loop(),
protocol_factory=lambda: ZnpMtProtocol(api, url=port),
url=port,
baudrate=baudrate,
xonxoff=(flow_control == "software"),
rtscts=(flow_control == "hardware"),
baudrate=config[zigpy.config.CONF_DEVICE_BAUDRATE],
flow_control=config[zigpy.config.CONF_DEVICE_FLOW_CONTROL],
)

await protocol._connected_event.wait()

LOGGER.debug("Connected to %s at %s baud", port, baudrate)
await protocol.wait_until_connected()

return protocol
2 changes: 1 addition & 1 deletion zigpy_znp/zigbee/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def disconnect(self):
except Exception as e:
LOGGER.warning("Failed to reset before disconnect: %s", e)
finally:
self._znp.close()
await self._znp.disconnect()
self._znp = None

async def add_endpoint(self, descriptor: zdo_t.SimpleDescriptor) -> None:
Expand Down

0 comments on commit fa6e64d

Please sign in to comment.