diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index 2726aacb..cfc8bd7b 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -14,12 +14,12 @@ import time from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, cast -import async_timeout from chip.ChipDeviceCtrl import DeviceProxyWrapper from chip.clusters import Attribute, Objects as Clusters from chip.clusters.Attribute import ValueDecodeFailure from chip.clusters.ClusterObjects import ALL_ATTRIBUTES, ALL_CLUSTERS, Cluster from chip.exceptions import ChipStackError +from chip.native import PyChipError from zeroconf import IPVersion, ServiceStateChange, Zeroconf from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf @@ -62,8 +62,6 @@ DATA_KEY_NODES = "nodes" DATA_KEY_LAST_NODE_ID = "last_node_id" -DEFAULT_CALL_TIMEOUT = 300 - LOGGER = logging.getLogger(__name__) MIN_NODE_SUBSCRIPTION_CEILING = 30 MAX_NODE_SUBSCRIPTION_CEILING = 300 @@ -113,13 +111,13 @@ def __init__( self.wifi_credentials_set: bool = False self.thread_credentials_set: bool = False self.compressed_fabric_id: int | None = None - self._node_lock: dict[int, asyncio.Lock] = {} self._aiobrowser: AsyncServiceBrowser | None = None self._aiozc: AsyncZeroconf | None = None self._fallback_node_scanner_timer: asyncio.TimerHandle | None = None self._sdk_executor = ThreadPoolExecutor( max_workers=1, thread_name_prefix="SDKExecutor" ) + self._node_setup_throttle = asyncio.Semaphore(10) async def initialize(self) -> None: """Async initialize of controller.""" @@ -249,15 +247,15 @@ async def commission_with_code( attempts, MAX_COMMISSION_RETRIES, ) - success = await self._call_sdk( + result: PyChipError | None = await self._call_sdk( self.chip_controller.CommissionWithCode, setupPayload=code, nodeid=node_id, networkOnly=network_only, ) - if success: + if result and result.is_success: break - if not success and attempts >= MAX_COMMISSION_RETRIES: + if attempts >= MAX_COMMISSION_RETRIES: raise NodeCommissionFailed( f"Commission with code failed for node {node_id}." ) @@ -321,6 +319,7 @@ async def commission_on_network( # by retrying, we increase the chances of a successful commission while attempts <= MAX_COMMISSION_RETRIES: attempts += 1 + result: PyChipError | None if ip_addr is None: # regular CommissionOnNetwork if no IP address provided LOGGER.info( @@ -329,7 +328,7 @@ async def commission_on_network( attempts, MAX_COMMISSION_RETRIES, ) - success = await self._call_sdk( + result = await self._call_sdk( self.chip_controller.CommissionOnNetwork, nodeId=node_id, setupPinCode=setup_pin_code, @@ -344,15 +343,15 @@ async def commission_on_network( attempts, MAX_COMMISSION_RETRIES, ) - success = await self._call_sdk( + result = await self._call_sdk( self.chip_controller.CommissionIP, nodeid=node_id, setupPinCode=setup_pin_code, ipaddr=ip_addr, ) - if success: + if result and result.is_success: break - if not success and attempts >= MAX_COMMISSION_RETRIES: + if attempts >= MAX_COMMISSION_RETRIES: raise NodeCommissionFailed(f"Commissioning failed for node {node_id}.") await asyncio.sleep(5) @@ -503,7 +502,6 @@ async def interview_node(self, node_id: int) -> None: try: if not (node := self._nodes.get(node_id)) or not node.available: await self._resolve_node(node_id=node_id) - async with self._get_node_lock(node_id): LOGGER.info("Interviewing node: %s", node_id) read_response: Attribute.AsyncReadTransaction.ReadResponse = ( await self.chip_controller.Read( @@ -567,15 +565,14 @@ async def send_device_command( cluster_cls: Cluster = ALL_CLUSTERS[cluster_id] command_cls = getattr(cluster_cls.Commands, command_name) command = dataclass_from_dict(command_cls, payload, allow_sdk_types=True) - async with self._get_node_lock(node_id): - return await self.chip_controller.SendCommand( - nodeid=node_id, - endpoint=endpoint_id, - payload=command, - responseType=response_type, - timedRequestTimeoutMs=timed_request_timeout_ms, - interactionTimeoutMs=interaction_timeout_ms, - ) + return await self.chip_controller.SendCommand( + nodeid=node_id, + endpoint=endpoint_id, + payload=command, + responseType=response_type, + timedRequestTimeoutMs=timed_request_timeout_ms, + interactionTimeoutMs=interaction_timeout_ms, + ) @api_command(APICommand.READ_ATTRIBUTE) async def read_attribute( @@ -587,30 +584,37 @@ async def read_attribute( if (node := self._nodes.get(node_id)) is None or not node.available: raise NodeNotReady(f"Node {node_id} is not (yet) available.") endpoint_id, cluster_id, attribute_id = parse_attribute_path(attribute_path) - assert self.server.loop is not None - async with self._get_node_lock(node_id): - future = self.server.loop.create_future() - device = await self._resolve_node(node_id) - Attribute.Read( - future=future, - eventLoop=self.server.loop, - device=device.deviceProxy, - devCtrl=self.chip_controller, - attributes=[ - Attribute.AttributePath( - EndpointId=endpoint_id, - ClusterId=cluster_id, - AttributeId=attribute_id, - ) - ], - fabricFiltered=fabric_filtered, - ).raise_on_error() - result: Attribute.AsyncReadTransaction.ReadResponse = await future - read_atributes = parse_attributes_from_read_result(result.tlvAttributes) - # update cached info in node attributes - self._nodes[node_id].attributes.update(read_atributes) - self._write_node_state(node_id) - return read_atributes + device = await self._resolve_node(node_id) + # Read a list of attributes and/or events from a target node. + # This is basically a re-implementation of the chip controller's Read function + # but one that allows us to send/request custom attributes. + + if TYPE_CHECKING: + assert self.server.loop + assert self.chip_controller + + future = self.server.loop.create_future() + device = await self._resolve_node(node_id) + Attribute.Read( + future=future, + eventLoop=self.server.loop, + device=device.deviceProxy, + devCtrl=self.chip_controller, + attributes=[ + Attribute.AttributePath( + EndpointId=endpoint_id, + ClusterId=cluster_id, + AttributeId=attribute_id, + ) + ], + fabricFiltered=fabric_filtered, + ).raise_on_error() + result: Attribute.AsyncReadTransaction.ReadResponse = await future + read_atributes = parse_attributes_from_read_result(result.tlvAttributes) + # update cached info in node attributes + self._nodes[node_id].attributes.update(read_atributes) + self._write_node_state(node_id) + return read_atributes @api_command(APICommand.WRITE_ATTRIBUTE) async def write_attribute( @@ -635,11 +639,10 @@ async def write_attribute( value_type=attribute.attribute_type.Type, allow_sdk_types=True, ) - async with self._get_node_lock(node_id): - return await self.chip_controller.WriteAttribute( - nodeid=node_id, - attributes=[(endpoint_id, attribute)], - ) + return await self.chip_controller.WriteAttribute( + nodeid=node_id, + attributes=[(endpoint_id, attribute)], + ) @api_command(APICommand.REMOVE_NODE) async def remove_node(self, node_id: int) -> None: @@ -678,12 +681,14 @@ async def remove_node(self, node_id: int) -> None: return result: Clusters.OperationalCredentials.Commands.NOCResponse | None = None try: - result = await self.chip_controller.SendCommand( - nodeid=node_id, - endpoint=0, - payload=Clusters.OperationalCredentials.Commands.RemoveFabric( - fabricIndex=fabric_index, - ), + result = await self._call_sdk( + self.chip_controller.SendCommand( + nodeid=node_id, + endpoint=0, + payload=Clusters.OperationalCredentials.Commands.RemoveFabric( + fabricIndex=fabric_index, + ), + ) ) except ChipStackError as err: LOGGER.warning( @@ -722,7 +727,9 @@ async def subscribe_attribute( ) @api_command(APICommand.PING_NODE) - async def ping_node(self, node_id: int, attempts: int = 1) -> NodePingResult: + async def ping_node( + self, node_id: int, attempts: int = 1, allow_cached_ips: bool = True + ) -> NodePingResult: """Ping node on the currently known IP-adress(es).""" result: NodePingResult = {} node = self._nodes.get(node_id) @@ -756,7 +763,7 @@ async def _do_ping(ip_address: str) -> None: result[clean_ip] = await ping_ip(ip_address, timeout, attempts=attempts) ip_addresses = await self.get_node_ip_addresses( - node_id, prefer_cache=False, scoped=True + node_id, prefer_cache=False, scoped=True, allow_cache=allow_cached_ips ) tasks = [_do_ping(x) for x in ip_addresses] # TODO: replace this gather with a taskgroup once we bump our py version @@ -781,7 +788,11 @@ async def _do_ping(ip_address: str) -> None: @api_command(APICommand.GET_NODE_IP_ADRESSES) async def get_node_ip_addresses( - self, node_id: int, prefer_cache: bool = False, scoped: bool = False + self, + node_id: int, + prefer_cache: bool = False, + scoped: bool = False, + allow_cache: bool = True, ) -> list[str]: """Return the currently known (scoped) IP-adress(es).""" cached_info = self._last_known_ip_addresses.get(node_id, []) @@ -799,7 +810,7 @@ async def get_node_ip_addresses( info = AsyncServiceInfo(MDNS_TYPE_OPERATIONAL_NODE, mdns_name) if TYPE_CHECKING: assert self._aiozc is not None - if not await info.async_request(self._aiozc.zeroconf, 3000): + if not await info.async_request(self._aiozc.zeroconf, 3000) and allow_cache: node_logger.info( "Node could not be discovered on the network, returning cached IP's" ) @@ -998,30 +1009,16 @@ def resubscription_succeeded( ) ) self._last_subscription_attempt[node_id] = 0 - future = loop.create_future() - device = await self._resolve_node(node_id) - async with async_timeout.timeout(DEFAULT_CALL_TIMEOUT): - Attribute.Read( - future=future, - eventLoop=loop, - device=device.deviceProxy, - devCtrl=self.chip_controller, - attributes=[Attribute.AttributePath()], # wildcard - events=[ - Attribute.EventPath( - EndpointId=None, Cluster=None, Event=None, Urgent=1 - ) - ], - returnClusterObject=False, - subscriptionParameters=Attribute.SubscriptionParameters( - interval_floor, interval_ceiling - ), - # Use fabricfiltered as False to detect changes made by other controllers - # and to be able to provide a list of all fabrics attached to the device - fabricFiltered=False, - autoResubscribe=True, - ).raise_on_error() - sub: Attribute.SubscriptionTransaction = await future + sub: Attribute.SubscriptionTransaction = await self.chip_controller.Read( + node_id, + attributes="*", + events=[("*", 1)], + returnClusterObject=False, + reportInterval=(interval_floor, interval_ceiling), + fabricFiltered=False, + keepSubscriptions=True, + autoResubscribe=True, + ) sub.SetAttributeUpdateCallback(attribute_updated_callback) sub.SetEventUpdateCallback(event_callback) @@ -1050,24 +1047,21 @@ def _get_next_node_id(self) -> int: async def _call_sdk( self, - func: Callable[..., _T], + target: Callable[..., _T], *args: Any, - call_timeout: int = DEFAULT_CALL_TIMEOUT, **kwargs: Any, ) -> _T: """Call function on the SDK in executor and return result.""" if self.server.loop is None: raise RuntimeError("Server not started.") - # prevent a single job in the executor blocking everything with a timeout. - async with async_timeout.timeout(call_timeout): - return cast( - _T, - await self.server.loop.run_in_executor( - self._sdk_executor, - partial(func, *args, **kwargs), - ), - ) + return cast( + _T, + await self.server.loop.run_in_executor( + self._sdk_executor, + partial(target, *args, **kwargs), + ), + ) async def _setup_node(self, node_id: int) -> None: """Handle set-up of subscriptions and interview (if needed) for known/discovered node.""" @@ -1078,51 +1072,56 @@ async def _setup_node(self, node_id: int) -> None: return self._nodes_in_setup.add(node_id) try: - # Ping the node to rule out stale mdns reports and to prevent that we - # send an unreachable node to the sdk which is very slow with resolving it. - # This will also precache the ip addresses of the node for later use. - ping_result = await self.ping_node(node_id, attempts=3) - if not any(ping_result.values()): - LOGGER.warning( - "Skip set-up for node %s because it does not appear to be reachable...", - node_id, + async with self._node_setup_throttle: + # Ping the node to rule out stale mdns reports and to prevent that we + # send an unreachable node to the sdk which is very slow with resolving it. + # This will also precache the ip addresses of the node for later use. + ping_result = await self.ping_node( + node_id, attempts=3, allow_cached_ips=False ) - return - LOGGER.info("Setting-up node %s...", node_id) - # (re)interview node (only) if needed - node_data = self._nodes[node_id] - if ( - # re-interview if we dont have any node attributes (empty node) - not node_data.attributes - # re-interview if the data model schema has changed - or node_data.interview_version != DATA_MODEL_SCHEMA_VERSION - ): + if not any(ping_result.values()): + LOGGER.warning( + "Skip set-up for node %s because it does not appear to be reachable...", + node_id, + ) + return + LOGGER.info("Setting-up node %s...", node_id) + # (re)interview node (only) if needed + node_data = self._nodes[node_id] + if ( + # re-interview if we dont have any node attributes (empty node) + not node_data.attributes + # re-interview if the data model schema has changed + or node_data.interview_version != DATA_MODEL_SCHEMA_VERSION + ): + try: + await self.interview_node(node_id) + except (NodeNotResolving, NodeInterviewFailed) as err: + LOGGER.warning( + "Unable to interview Node %s: %s", + node_id, + str(err) or err.__class__.__name__, + # log full stack trace if debug logging is enabled + exc_info=err + if LOGGER.isEnabledFor(logging.DEBUG) + else None, + ) + # NOTE: the node will be picked up by mdns discovery automatically + # when it comes available again. + return + # setup subscriptions for the node try: - await self.interview_node(node_id) - except (NodeNotResolving, NodeInterviewFailed) as err: + await self._subscribe_node(node_id) + except (NodeNotResolving, ChipStackError) as err: LOGGER.warning( - "Unable to interview Node %s: %s", + "Unable to subscribe to Node %s: %s", node_id, str(err) or err.__class__.__name__, # log full stack trace if debug logging is enabled exc_info=err if LOGGER.isEnabledFor(logging.DEBUG) else None, ) # NOTE: the node will be picked up by mdns discovery automatically - # when it comes available again. - return - # setup subscriptions for the node - try: - await self._subscribe_node(node_id) - except (NodeNotResolving, TimeoutError) as err: - LOGGER.warning( - "Unable to subscribe to Node %s: %s", - node_id, - str(err) or err.__class__.__name__, - # log full stack trace if debug logging is enabled - exc_info=err if LOGGER.isEnabledFor(logging.DEBUG) else None, - ) - # NOTE: the node will be picked up by mdns discovery automatically - # when it becomes available again. + # when it becomes available again. finally: self._nodes_in_setup.discard(node_id) @@ -1147,7 +1146,7 @@ async def _resolve_node( allowPASE=False, timeoutMs=None, ) - except (ChipStackError, TimeoutError) as err: + except ChipStackError as err: if attempt >= retries: # when we're out of retries, raise NodeNotResolving raise NodeNotResolving(f"Unable to resolve Node {node_id}") from err @@ -1254,12 +1253,6 @@ async def _on_mdns_commissionable_node_state( await info.async_request(self._aiozc.zeroconf, 3000) LOGGER.debug("Discovered commissionable Matter node using MDNS: %s", info) - def _get_node_lock(self, node_id: int) -> asyncio.Lock: - """Return lock for given node.""" - if node_id not in self._node_lock: - self._node_lock[node_id] = asyncio.Lock() - return self._node_lock[node_id] - def _write_node_state(self, node_id: int, force: bool = False) -> None: """Schedule the write of the current node state to persistent storage.""" node = self._nodes[node_id] @@ -1295,7 +1288,7 @@ async def _fallback_node_scanner(self) -> None: last_seen = self._node_last_seen.get(node_id, 0) if now - last_seen < FALLBACK_NODE_SCANNER_INTERVAL: continue - if await self.ping_node(node_id, attempts=3): + if await self.ping_node(node_id, attempts=3, allow_cached_ips=False): LOGGER.info("Node %s discovered using fallback ping", node_id) await self._setup_node(node_id)