Skip to content

Commit

Permalink
Support disabling devices (#43293)
Browse files Browse the repository at this point in the history
  • Loading branch information
emontnemery authored Nov 26, 2020
1 parent 39efbcb commit dc8364f
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 10 deletions.
3 changes: 3 additions & 0 deletions homeassistant/components/config/device_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
vol.Required("device_id"): str,
vol.Optional("area_id"): vol.Any(str, None),
vol.Optional("name_by_user"): vol.Any(str, None),
# We only allow setting disabled_by user via API.
vol.Optional("disabled_by"): vol.Any("user", None),
}
)

Expand Down Expand Up @@ -77,4 +79,5 @@ def _entry_dict(entry):
"via_device_id": entry.via_device_id,
"area_id": entry.area_id,
"name_by_user": entry.name_by_user,
"disabled_by": entry.disabled_by,
}
28 changes: 28 additions & 0 deletions homeassistant/helpers/device_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
REGISTERED_DEVICE = "registered"
DELETED_DEVICE = "deleted"

DISABLED_INTEGRATION = "integration"
DISABLED_USER = "user"


@attr.s(slots=True, frozen=True)
class DeletedDeviceEntry:
Expand Down Expand Up @@ -76,6 +79,21 @@ class DeviceEntry:
id: str = attr.ib(factory=uuid_util.random_uuid_hex)
# This value is not stored, just used to keep track of events to fire.
is_new: bool = attr.ib(default=False)
disabled_by: Optional[str] = attr.ib(
default=None,
validator=attr.validators.in_(
(
DISABLED_INTEGRATION,
DISABLED_USER,
None,
)
),
)

@property
def disabled(self) -> bool:
"""Return if entry is disabled."""
return self.disabled_by is not None


def format_mac(mac: str) -> str:
Expand Down Expand Up @@ -215,6 +233,8 @@ def async_get_or_create(
sw_version=_UNDEF,
entry_type=_UNDEF,
via_device=None,
# To disable a device if it gets created
disabled_by=_UNDEF,
):
"""Get device. Create if it doesn't exist."""
if not identifiers and not connections:
Expand Down Expand Up @@ -267,6 +287,7 @@ def async_get_or_create(
name=name,
sw_version=sw_version,
entry_type=entry_type,
disabled_by=disabled_by,
)

@callback
Expand All @@ -283,6 +304,7 @@ def async_update_device(
sw_version=_UNDEF,
via_device_id=_UNDEF,
remove_config_entry_id=_UNDEF,
disabled_by=_UNDEF,
):
"""Update properties of a device."""
return self._async_update_device(
Expand All @@ -296,6 +318,7 @@ def async_update_device(
sw_version=sw_version,
via_device_id=via_device_id,
remove_config_entry_id=remove_config_entry_id,
disabled_by=disabled_by,
)

@callback
Expand All @@ -316,6 +339,7 @@ def _async_update_device(
via_device_id=_UNDEF,
area_id=_UNDEF,
name_by_user=_UNDEF,
disabled_by=_UNDEF,
):
"""Update device attributes."""
old = self.devices[device_id]
Expand Down Expand Up @@ -362,6 +386,7 @@ def _async_update_device(
("sw_version", sw_version),
("entry_type", entry_type),
("via_device_id", via_device_id),
("disabled_by", disabled_by),
):
if value is not _UNDEF and value != getattr(old, attr_name):
changes[attr_name] = value
Expand Down Expand Up @@ -440,6 +465,8 @@ async def async_load(self):
# Introduced in 0.87
area_id=device.get("area_id"),
name_by_user=device.get("name_by_user"),
# Introduced in 0.119
disabled_by=device.get("disabled_by"),
)
# Introduced in 0.111
for device in data.get("deleted_devices", []):
Expand Down Expand Up @@ -478,6 +505,7 @@ def _data_to_save(self) -> Dict[str, List[Dict[str, Any]]]:
"via_device_id": entry.via_device_id,
"area_id": entry.area_id,
"name_by_user": entry.name_by_user,
"disabled_by": entry.disabled_by,
}
for entry in self.devices.values()
]
Expand Down
36 changes: 27 additions & 9 deletions homeassistant/helpers/entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@
_LOGGER = logging.getLogger(__name__)
_UNDEF = object()
DISABLED_CONFIG_ENTRY = "config_entry"
DISABLED_DEVICE = "device"
DISABLED_HASS = "hass"
DISABLED_USER = "user"
DISABLED_INTEGRATION = "integration"
DISABLED_USER = "user"

STORAGE_VERSION = 1
STORAGE_KEY = "core.entity_registry"
Expand Down Expand Up @@ -89,10 +90,11 @@ class RegistryEntry:
default=None,
validator=attr.validators.in_(
(
DISABLED_CONFIG_ENTRY,
DISABLED_DEVICE,
DISABLED_HASS,
DISABLED_USER,
DISABLED_INTEGRATION,
DISABLED_CONFIG_ENTRY,
DISABLED_USER,
None,
)
),
Expand Down Expand Up @@ -127,7 +129,7 @@ def __init__(self, hass: HomeAssistantType):
self._index: Dict[Tuple[str, str, str], str] = {}
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self.hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_removed
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_modified
)

@callback
Expand Down Expand Up @@ -286,18 +288,34 @@ def async_remove(self, entity_id: str) -> None:
)
self.async_schedule_save()

@callback
def async_device_removed(self, event: Event) -> None:
"""Handle the removal of a device.
async def async_device_modified(self, event: Event) -> None:
"""Handle the removal or update of a device.
Remove entities from the registry that are associated to a device when
the device is removed.
Disable entities in the registry that are associated to a device when
the device is disabled.
"""
if event.data["action"] != "remove":
if event.data["action"] == "remove":
entities = async_entries_for_device(self, event.data["device_id"])
for entity in entities:
self.async_remove(entity.entity_id)
return

if event.data["action"] != "update":
return

device_registry = await self.hass.helpers.device_registry.async_get_registry()
device = device_registry.async_get(event.data["device_id"])
if not device.disabled:
return

entities = async_entries_for_device(self, event.data["device_id"])
for entity in entities:
self.async_remove(entity.entity_id)
self.async_update_entity( # type: ignore
entity.entity_id, disabled_by=DISABLED_DEVICE
)

@callback
def async_update_entity(
Expand Down
4 changes: 4 additions & 0 deletions tests/components/config/test_device_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ async def test_list_devices(hass, client, registry):
"via_device_id": None,
"area_id": None,
"name_by_user": None,
"disabled_by": None,
},
{
"config_entries": ["1234"],
Expand All @@ -69,6 +70,7 @@ async def test_list_devices(hass, client, registry):
"via_device_id": dev1,
"area_id": None,
"name_by_user": None,
"disabled_by": None,
},
]

Expand All @@ -92,6 +94,7 @@ async def test_update_device(hass, client, registry):
"device_id": device.id,
"area_id": "12345A",
"name_by_user": "Test Friendly Name",
"disabled_by": "user",
"type": "config/device_registry/update",
}
)
Expand All @@ -101,4 +104,5 @@ async def test_update_device(hass, client, registry):
assert msg["result"]["id"] == device.id
assert msg["result"]["area_id"] == "12345A"
assert msg["result"]["name_by_user"] == "Test Friendly Name"
assert msg["result"]["disabled_by"] == "user"
assert len(registry.devices) == 1
5 changes: 5 additions & 0 deletions tests/helpers/test_device_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ async def test_loading_from_storage(hass, hass_storage):
"entry_type": "service",
"area_id": "12345A",
"name_by_user": "Test Friendly Name",
"disabled_by": "user",
}
],
"deleted_devices": [
Expand Down Expand Up @@ -180,6 +181,7 @@ async def test_loading_from_storage(hass, hass_storage):
assert entry.area_id == "12345A"
assert entry.name_by_user == "Test Friendly Name"
assert entry.entry_type == "service"
assert entry.disabled_by == "user"
assert isinstance(entry.config_entries, set)
assert isinstance(entry.connections, set)
assert isinstance(entry.identifiers, set)
Expand Down Expand Up @@ -445,6 +447,7 @@ async def test_loading_saving_data(hass, registry):
manufacturer="manufacturer",
model="light",
via_device=("hue", "0123"),
disabled_by="user",
)

orig_light2 = registry.async_get_or_create(
Expand Down Expand Up @@ -581,6 +584,7 @@ async def test_update(registry):
name_by_user="Test Friendly Name",
new_identifiers=new_identifiers,
via_device_id="98765B",
disabled_by="user",
)

assert mock_save.call_count == 1
Expand All @@ -591,6 +595,7 @@ async def test_update(registry):
assert updated_entry.name_by_user == "Test Friendly Name"
assert updated_entry.identifiers == new_identifiers
assert updated_entry.via_device_id == "98765B"
assert updated_entry.disabled_by == "user"

assert registry.async_get_device({("hue", "456")}, {}) is None
assert registry.async_get_device({("bla", "123")}, {}) is None
Expand Down
61 changes: 60 additions & 1 deletion tests/helpers/test_entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

import tests.async_mock
from tests.async_mock import patch
from tests.common import MockConfigEntry, flush_store, mock_registry
from tests.common import (
MockConfigEntry,
flush_store,
mock_device_registry,
mock_registry,
)

YAML__OPEN_PATH = "homeassistant.util.yaml.loader.open"

Expand Down Expand Up @@ -677,3 +682,57 @@ async def test_async_get_device_class_lookup(hass):
("sensor", "battery"): "sensor.vacuum_battery",
},
}


async def test_remove_device_removes_entities(hass, registry):
"""Test that we remove entities tied to a device."""
device_registry = mock_device_registry(hass)
config_entry = MockConfigEntry(domain="light")

device_entry = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={("mac", "12:34:56:AB:CD:EF")},
)

entry = registry.async_get_or_create(
"light",
"hue",
"5678",
config_entry=config_entry,
device_id=device_entry.id,
)

assert registry.async_is_registered(entry.entity_id)

device_registry.async_remove_device(device_entry.id)
await hass.async_block_till_done()

assert not registry.async_is_registered(entry.entity_id)


async def test_disable_device_disables_entities(hass, registry):
"""Test that we remove entities tied to a device."""
device_registry = mock_device_registry(hass)
config_entry = MockConfigEntry(domain="light")

device_entry = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={("mac", "12:34:56:AB:CD:EF")},
)

entry = registry.async_get_or_create(
"light",
"hue",
"5678",
config_entry=config_entry,
device_id=device_entry.id,
)

assert not entry.disabled

device_registry.async_update_device(device_entry.id, disabled_by="user")
await hass.async_block_till_done()

entry = registry.async_get(entry.entity_id)
assert entry.disabled
assert entry.disabled_by == "device"

0 comments on commit dc8364f

Please sign in to comment.