Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure network startup is consistent #140

Merged
merged 1 commit into from
Apr 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 48 additions & 48 deletions tests/test_application.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio

import pytest
from zigpy import types as t
import zigpy.exceptions
from zigpy.zdo.types import ZDOCmd
import zigpy.state
import zigpy.types as t
import zigpy.zdo.types as zdo_t

from zigpy_xbee.api import ModemStatus, XBee
import zigpy_xbee.config as config
Expand All @@ -21,6 +22,42 @@
}


@pytest.fixture
def node_info():
return zigpy.state.NodeInfo(
nwk=t.NWK(0x0000),
ieee=t.EUI64.convert("00:12:4b:00:1c:a1:b8:46"),
logical_type=zdo_t.LogicalType.Coordinator,
)


@pytest.fixture
def network_info(node_info):
return zigpy.state.NetworkInfo(
extended_pan_id=t.ExtendedPanId.convert("bd:27:0b:38:37:95:dc:87"),
pan_id=t.PanId(0x9BB0),
nwk_update_id=18,
nwk_manager_id=t.NWK(0x0000),
channel=t.uint8_t(15),
channel_mask=t.Channels.ALL_CHANNELS,
security_level=t.uint8_t(5),
network_key=zigpy.state.Key(
key=t.KeyData.convert("2ccade06b3090c310315b3d574d3c85a"),
seq=108,
tx_counter=118785,
),
tc_link_key=zigpy.state.Key(
key=t.KeyData(b"ZigBeeAlliance09"),
partner_ieee=node_info.ieee,
tx_counter=8712428,
),
key_table=[],
children=[],
nwk_addresses={},
source="[email protected]",
)


@pytest.fixture
def app(monkeypatch):
monkeypatch.setattr(application, "TIMEOUT_TX_STATUS", 0.1)
Expand Down Expand Up @@ -242,59 +279,22 @@ async def test_get_association_state(app):
assert ai is mock.sentinel.ai


async def test_form_network(app):
legacy_module = False

async def mock_at_command(cmd, *args):
if cmd == "MY":
return 0x0000
if cmd == "OI":
return 0x1234
elif cmd == "ID":
return 0x1234567812345678
elif cmd == "SL":
return 0x11223344
elif cmd == "SH":
return 0x55667788
elif cmd == "WR":
app._api.coordinator_started_event.set()
elif cmd == "CE" and legacy_module:
raise RuntimeError
return None
async def test_write_network_info(app, node_info, network_info):
app._api._queued_at = mock.AsyncMock(spec=XBee._queued_at)
app._api._at_command = mock.AsyncMock(spec=XBee._at_command)
app._api._running = mock.AsyncMock(spec=app._api._running)

app._api._at_command = mock.MagicMock(
spec=XBee._at_command, side_effect=mock_at_command
)
app._api._queued_at = mock.MagicMock(
spec=XBee._at_command, side_effect=mock_at_command
)
app._get_association_state = mock.AsyncMock(
spec=application.ControllerApplication._get_association_state,
return_value=0x00,
)

app.write_network_info = mock.MagicMock(wraps=app.write_network_info)

await app.form_network()
assert app._api._at_command.call_count >= 1
assert app._api._queued_at.call_count >= 7

network_info = app.write_network_info.mock_calls[0][2]["network_info"]

app._api._queued_at.assert_any_call("SC", 1 << (network_info.channel - 11))
app._api._queued_at.assert_any_call("KY", b"ZigBeeAlliance09")

app._api._at_command.reset_mock()
app._api._queued_at.reset_mock()
legacy_module = True
await app.form_network()
assert app._api._at_command.call_count >= 1
assert app._api._queued_at.call_count >= 7

network_info = app.write_network_info.mock_calls[0][2]["network_info"]
await app.write_network_info(network_info=network_info, node_info=node_info)

app._api._queued_at.assert_any_call("SC", 1 << (network_info.channel - 11))
app._api._queued_at.assert_any_call("KY", b"ZigBeeAlliance09")
app._api._queued_at.assert_any_call("NK", network_info.network_key.key.serialize())
app._api._queued_at.assert_any_call("ID", 0xBD270B383795DC87)


async def _test_start_network(
Expand Down Expand Up @@ -435,7 +435,7 @@ def _mock_command(
seq,
b"\xaa\x55\xbe\xef",
expect_reply=expect_reply,
**kwargs
**kwargs,
)


Expand Down Expand Up @@ -509,7 +509,7 @@ def nwk():

def test_rx_device_annce(app, ieee, nwk):
dst_ep = 0
cluster_id = ZDOCmd.Device_annce
cluster_id = zdo_t.ZDOCmd.Device_annce
device = mock.MagicMock()
device.status = device.Status.NEW
app.get_device = mock.MagicMock(return_value=device)
Expand Down
16 changes: 11 additions & 5 deletions zigpy_xbee/zigbee/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import zigpy.device
import zigpy.exceptions
import zigpy.quirks
import zigpy.state
import zigpy.types
import zigpy.util
from zigpy.zcl import foundation
Expand Down Expand Up @@ -69,13 +70,16 @@ async def start_network(self):
# Enable ZDO passthrough
await self._api._at_command("AO", 0x03)

if self.state.node_info == zigpy.state.NodeInfo():
await self.load_network_info()

enc_enabled = await self._api._at_command("EE")
enc_options = await self._api._at_command("EO")
zb_profile = await self._api._at_command("ZS")

if (
enc_enabled != 1
or enc_options != 2
or enc_options & 0b0010 != 0b0010
or zb_profile != 2
or association_state != 0
or self.state.node_info.nwk != 0x0000
Expand Down Expand Up @@ -134,16 +138,18 @@ async def reset_network_info(self) -> None:
await self._api._at_command("NR", 0)

async def write_network_info(self, *, network_info, node_info):
scan_bitmask = 1 << (network_info.channel - 11)
epid, _ = zigpy.types.uint64_t.deserialize(
network_info.extended_pan_id.serialize()
)
await self._api._queued_at("ID", epid)

await self._api._queued_at("ZS", 2)
scan_bitmask = 1 << (network_info.channel - 11)
await self._api._queued_at("SC", scan_bitmask)
await self._api._queued_at("EE", 1)
await self._api._queued_at("EO", 2)

await self._api._queued_at("EO", 0b0010)
await self._api._queued_at("NK", network_info.network_key.key.serialize())
await self._api._queued_at("KY", network_info.tc_link_key.key.serialize())

await self._api._queued_at("NJ", 0)
await self._api._queued_at("SP", CONF_CYCLIC_SLEEP_PERIOD)
await self._api._queued_at("SN", CONF_POLL_TIMEOUT)
Expand Down