From 5903d64ab175adb6468fd86035f3579268e37f95 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Fri, 21 Apr 2023 19:25:19 -0400 Subject: [PATCH] Ensure network startup is consistent --- tests/test_application.py | 96 ++++++++++++++++---------------- zigpy_xbee/zigbee/application.py | 16 ++++-- 2 files changed, 59 insertions(+), 53 deletions(-) diff --git a/tests/test_application.py b/tests/test_application.py index b98ea67..7323699 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,9 +1,10 @@ import asyncio import pytest -from zigpy import types as t import zigpy.exceptions -from zigpy.zdo.types import ZDOCmd +import zigpy.state +import zigpy.types as t +import zigpy.zdo.types as zdo_t from zigpy_xbee.api import ModemStatus, XBee import zigpy_xbee.config as config @@ -21,6 +22,42 @@ } +@pytest.fixture +def node_info(): + return zigpy.state.NodeInfo( + nwk=t.NWK(0x0000), + ieee=t.EUI64.convert("00:12:4b:00:1c:a1:b8:46"), + logical_type=zdo_t.LogicalType.Coordinator, + ) + + +@pytest.fixture +def network_info(node_info): + return zigpy.state.NetworkInfo( + extended_pan_id=t.ExtendedPanId.convert("bd:27:0b:38:37:95:dc:87"), + pan_id=t.PanId(0x9BB0), + nwk_update_id=18, + nwk_manager_id=t.NWK(0x0000), + channel=t.uint8_t(15), + channel_mask=t.Channels.ALL_CHANNELS, + security_level=t.uint8_t(5), + network_key=zigpy.state.Key( + key=t.KeyData.convert("2ccade06b3090c310315b3d574d3c85a"), + seq=108, + tx_counter=118785, + ), + tc_link_key=zigpy.state.Key( + key=t.KeyData(b"ZigBeeAlliance09"), + partner_ieee=node_info.ieee, + tx_counter=8712428, + ), + key_table=[], + children=[], + nwk_addresses={}, + source="zigpy-xbee@0.0.0", + ) + + @pytest.fixture def app(monkeypatch): monkeypatch.setattr(application, "TIMEOUT_TX_STATUS", 0.1) @@ -242,59 +279,22 @@ async def test_get_association_state(app): assert ai is mock.sentinel.ai -async def test_form_network(app): - legacy_module = False - - async def mock_at_command(cmd, *args): - if cmd == "MY": - return 0x0000 - if cmd == "OI": - return 0x1234 - elif cmd == "ID": - return 0x1234567812345678 - elif cmd == "SL": - return 0x11223344 - elif cmd == "SH": - return 0x55667788 - elif cmd == "WR": - app._api.coordinator_started_event.set() - elif cmd == "CE" and legacy_module: - raise RuntimeError - return None +async def test_write_network_info(app, node_info, network_info): + app._api._queued_at = mock.AsyncMock(spec=XBee._queued_at) + app._api._at_command = mock.AsyncMock(spec=XBee._at_command) + app._api._running = mock.AsyncMock(spec=app._api._running) - app._api._at_command = mock.MagicMock( - spec=XBee._at_command, side_effect=mock_at_command - ) - app._api._queued_at = mock.MagicMock( - spec=XBee._at_command, side_effect=mock_at_command - ) app._get_association_state = mock.AsyncMock( spec=application.ControllerApplication._get_association_state, return_value=0x00, ) - app.write_network_info = mock.MagicMock(wraps=app.write_network_info) - - await app.form_network() - assert app._api._at_command.call_count >= 1 - assert app._api._queued_at.call_count >= 7 - - network_info = app.write_network_info.mock_calls[0][2]["network_info"] - - app._api._queued_at.assert_any_call("SC", 1 << (network_info.channel - 11)) - app._api._queued_at.assert_any_call("KY", b"ZigBeeAlliance09") - - app._api._at_command.reset_mock() - app._api._queued_at.reset_mock() - legacy_module = True - await app.form_network() - assert app._api._at_command.call_count >= 1 - assert app._api._queued_at.call_count >= 7 - - network_info = app.write_network_info.mock_calls[0][2]["network_info"] + await app.write_network_info(network_info=network_info, node_info=node_info) app._api._queued_at.assert_any_call("SC", 1 << (network_info.channel - 11)) app._api._queued_at.assert_any_call("KY", b"ZigBeeAlliance09") + app._api._queued_at.assert_any_call("NK", network_info.network_key.key.serialize()) + app._api._queued_at.assert_any_call("ID", 0xBD270B383795DC87) async def _test_start_network( @@ -435,7 +435,7 @@ def _mock_command( seq, b"\xaa\x55\xbe\xef", expect_reply=expect_reply, - **kwargs + **kwargs, ) @@ -509,7 +509,7 @@ def nwk(): def test_rx_device_annce(app, ieee, nwk): dst_ep = 0 - cluster_id = ZDOCmd.Device_annce + cluster_id = zdo_t.ZDOCmd.Device_annce device = mock.MagicMock() device.status = device.Status.NEW app.get_device = mock.MagicMock(return_value=device) diff --git a/zigpy_xbee/zigbee/application.py b/zigpy_xbee/zigbee/application.py index 05e885d..1bd5337 100644 --- a/zigpy_xbee/zigbee/application.py +++ b/zigpy_xbee/zigbee/application.py @@ -10,6 +10,7 @@ import zigpy.device import zigpy.exceptions import zigpy.quirks +import zigpy.state import zigpy.types import zigpy.util from zigpy.zcl import foundation @@ -69,13 +70,16 @@ async def start_network(self): # Enable ZDO passthrough await self._api._at_command("AO", 0x03) + if self.state.node_info == zigpy.state.NodeInfo(): + await self.load_network_info() + enc_enabled = await self._api._at_command("EE") enc_options = await self._api._at_command("EO") zb_profile = await self._api._at_command("ZS") if ( enc_enabled != 1 - or enc_options != 2 + or enc_options & 0b0010 != 0b0010 or zb_profile != 2 or association_state != 0 or self.state.node_info.nwk != 0x0000 @@ -134,16 +138,18 @@ async def reset_network_info(self) -> None: await self._api._at_command("NR", 0) async def write_network_info(self, *, network_info, node_info): - scan_bitmask = 1 << (network_info.channel - 11) + epid, _ = zigpy.types.uint64_t.deserialize( + network_info.extended_pan_id.serialize() + ) + await self._api._queued_at("ID", epid) await self._api._queued_at("ZS", 2) + scan_bitmask = 1 << (network_info.channel - 11) await self._api._queued_at("SC", scan_bitmask) await self._api._queued_at("EE", 1) - await self._api._queued_at("EO", 2) - + await self._api._queued_at("EO", 0b0010) await self._api._queued_at("NK", network_info.network_key.key.serialize()) await self._api._queued_at("KY", network_info.tc_link_key.key.serialize()) - await self._api._queued_at("NJ", 0) await self._api._queued_at("SP", CONF_CYCLIC_SLEEP_PERIOD) await self._api._queued_at("SN", CONF_POLL_TIMEOUT)