diff --git a/homeassistant/components/zha/helpers.py b/homeassistant/components/zha/helpers.py index 0689929699164..f24f6a34a8c22 100644 --- a/homeassistant/components/zha/helpers.py +++ b/homeassistant/components/zha/helpers.py @@ -104,7 +104,7 @@ ATTR_NAME, Platform, ) -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import Event, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import ( config_validation as cv, @@ -495,7 +495,7 @@ def __init__( self.hass = hass self.config_entry = config_entry self.gateway = gateway - self.device_proxies: dict[str, ZHADeviceProxy] = {} + self.device_proxies: dict[EUI64, ZHADeviceProxy] = {} self.group_proxies: dict[int, ZHAGroupProxy] = {} self._ha_entity_refs: collections.defaultdict[EUI64, list[EntityReference]] = ( collections.defaultdict(list) @@ -509,6 +509,12 @@ def __init__( self._unsubs: list[Callable[[], None]] = [] self._unsubs.append(self.gateway.on_all_events(self._handle_event_protocol)) self._reload_task: asyncio.Task | None = None + config_entry.async_on_unload( + self.hass.bus.async_listen( + er.EVENT_ENTITY_REGISTRY_UPDATED, + self._handle_entity_registry_updated, + ) + ) @property def ha_entity_refs(self) -> collections.defaultdict[EUI64, list[EntityReference]]: @@ -532,6 +538,46 @@ def register_entity_reference( ) ) + async def _handle_entity_registry_updated( + self, event: Event[er.EventEntityRegistryUpdatedData] + ) -> None: + """Handle when entity registry updated.""" + entity_id = event.data["entity_id"] + entity_entry: er.RegistryEntry | None = er.async_get(self.hass).async_get( + entity_id + ) + if ( + entity_entry is None + or entity_entry.config_entry_id != self.config_entry.entry_id + or entity_entry.device_id is None + ): + return + device_entry: dr.DeviceEntry | None = dr.async_get(self.hass).async_get( + entity_entry.device_id + ) + assert device_entry + + ieee_address = next( + identifier + for domain, identifier in device_entry.identifiers + if domain == DOMAIN + ) + assert ieee_address + + ieee = EUI64.convert(ieee_address) + + assert ieee in self.device_proxies + + zha_device_proxy = self.device_proxies[ieee] + entity_key = (entity_entry.domain, entity_entry.unique_id) + if entity_key not in zha_device_proxy.device.platform_entities: + return + platform_entity = zha_device_proxy.device.platform_entities[entity_key] + if entity_entry.disabled: + platform_entity.disable() + else: + platform_entity.enable() + async def async_initialize_devices_and_entities(self) -> None: """Initialize devices and entities.""" for device in self.gateway.devices.values(): @@ -1117,7 +1163,7 @@ def async_add_entities( if not entities: return - entities_to_add = [] + entities_to_add: list[ZHAEntity] = [] for entity_data in entities: try: entities_to_add.append(entity_class(entity_data)) @@ -1129,6 +1175,9 @@ def async_add_entities( "Error while adding entity from entity data: %s", entity_data ) _async_add_entities(entities_to_add, update_before_add=False) + for entity in entities_to_add: + if not entity.enabled: + entity.entity_data.entity.disable() entities.clear() diff --git a/tests/components/zha/test_binary_sensor.py b/tests/components/zha/test_binary_sensor.py index 419823b3b5202..a9765a1b54779 100644 --- a/tests/components/zha/test_binary_sensor.py +++ b/tests/components/zha/test_binary_sensor.py @@ -14,6 +14,7 @@ ) from homeassistant.const import STATE_OFF, STATE_ON, Platform from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er from .common import find_entity_id, send_attributes_report from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE @@ -37,6 +38,7 @@ def binary_sensor_platform_only(): async def test_binary_sensor( hass: HomeAssistant, + entity_registry: er.EntityRegistry, setup_zha, zigpy_device_mock, ) -> None: @@ -77,3 +79,20 @@ async def test_binary_sensor( hass, cluster, {general.OnOff.AttributeDefs.on_off.id: OFF} ) assert hass.states.get(entity_id).state == STATE_OFF + + # test enable / disable sync w/ ZHA library + entity_entry = entity_registry.async_get(entity_id) + entity_key = (Platform.BINARY_SENSOR, entity_entry.unique_id) + assert zha_device_proxy.device.platform_entities.get(entity_key).enabled + + entity_registry.async_update_entity( + entity_id=entity_id, disabled_by=er.RegistryEntryDisabler.USER + ) + await hass.async_block_till_done() + + assert not zha_device_proxy.device.platform_entities.get(entity_key).enabled + + entity_registry.async_update_entity(entity_id=entity_id, disabled_by=None) + await hass.async_block_till_done() + + assert zha_device_proxy.device.platform_entities.get(entity_key).enabled