From 86f18d8901d8fd9b6e2ebfa9c3926ed1d1d0e45c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 16 Jun 2024 19:54:15 -0500 Subject: [PATCH] feat: small cleanup to get device functions (#76) --- src/uiprotect/data/bootstrap.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/uiprotect/data/bootstrap.py b/src/uiprotect/data/bootstrap.py index a7f922b3..0ec14316 100644 --- a/src/uiprotect/data/bootstrap.py +++ b/src/uiprotect/data/bootstrap.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from datetime import datetime from functools import cache -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any from aiohttp.client_exceptions import ServerDisconnectedError @@ -298,20 +298,20 @@ def has_media(self) -> bool: def get_device_from_mac(self, mac: str) -> ProtectAdoptableDeviceModel | None: """Retrieve a device from MAC address.""" - ref = self.mac_lookup.get(normalize_mac(mac)) - if ref is None: - return None - - devices: dict[str, ProtectModelWithId] = getattr(self, ref.model.devices_key) - return cast(ProtectAdoptableDeviceModel, devices.get(ref.id)) + return self._get_device_from_ref(self.mac_lookup.get(normalize_mac(mac))) def get_device_from_id(self, device_id: str) -> ProtectAdoptableDeviceModel | None: """Retrieve a device from device ID (without knowing model type).""" - ref = self.id_lookup.get(device_id) + return self._get_device_from_ref(self.id_lookup.get(device_id)) + + def _get_device_from_ref( + self, ref: ProtectDeviceRef | None + ) -> ProtectAdoptableDeviceModel | None: if ref is None: return None - devices: dict[str, ProtectModelWithId] = getattr(self, ref.model.devices_key) - return cast(ProtectAdoptableDeviceModel, devices.get(ref.id)) + devices_key = ref.model.devices_key + devices: dict[str, ProtectAdoptableDeviceModel] = getattr(self, devices_key) + return devices[ref.id] def process_event(self, event: Event) -> None: event_type = event.type @@ -386,9 +386,8 @@ def _process_add_packet( def _process_remove_packet( self, model_type: ModelType, packet: WSPacket ) -> WSSubscriptionMessage | None: - devices: dict[str, ProtectDeviceModel] | None = getattr( - self, model_type.devices_key, None - ) + devices_key = model_type.devices_key + devices: dict[str, ProtectDeviceModel] | None = getattr(self, devices_key, None) if devices is None: return None @@ -615,9 +614,8 @@ async def refresh_device(self, model_type: ModelType, device_id: str) -> None: if isinstance(device, NVR): self.nvr = device else: - devices: dict[str, ProtectModelWithId] = getattr( - self, model_type.devices_key - ) + devices_key = model_type.devices_key + devices: dict[str, ProtectModelWithId] = getattr(self, devices_key) devices[device.id] = device _LOGGER.debug("Successfully refresh model: %s %s", model_type, device_id)