Skip to content

Commit

Permalink
Add HassDict implementation (#103844)
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p authored May 7, 2024
1 parent fd52588 commit 3d700e2
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 36 deletions.
4 changes: 1 addition & 3 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2035,9 +2035,7 @@ async def async_wait_component(self, entry: ConfigEntry) -> bool:
Config entries which are created after Home Assistant is started can't be waited
for, the function will just return if the config entry is loaded or not.
"""
setup_done: dict[str, asyncio.Future[bool]] = self.hass.data.get(
DATA_SETUP_DONE, {}
)
setup_done = self.hass.data.get(DATA_SETUP_DONE, {})
if setup_future := setup_done.get(entry.domain):
await setup_future
# The component was not loaded.
Expand Down
3 changes: 2 additions & 1 deletion homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
)
from .util.event_type import EventType
from .util.executor import InterruptibleThreadPoolExecutor
from .util.hass_dict import HassDict
from .util.json import JsonObjectType
from .util.read_only_dict import ReadOnlyDict
from .util.timeout import TimeoutManager
Expand Down Expand Up @@ -406,7 +407,7 @@ def __init__(self, config_dir: str) -> None:
from . import loader

# This is a dictionary that any component can store any data on.
self.data: dict[str, Any] = {}
self.data = HassDict()
self.loop = asyncio.get_running_loop()
self._tasks: set[asyncio.Future[Any]] = set()
self._background_tasks: set[asyncio.Future[Any]] = set()
Expand Down
13 changes: 11 additions & 2 deletions homeassistant/helpers/singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,26 @@
import asyncio
from collections.abc import Callable
import functools
from typing import Any, TypeVar, cast
from typing import Any, TypeVar, cast, overload

from homeassistant.core import HomeAssistant
from homeassistant.loader import bind_hass
from homeassistant.util.hass_dict import HassKey

_T = TypeVar("_T")

_FuncType = Callable[[HomeAssistant], _T]


def singleton(data_key: str) -> Callable[[_FuncType[_T]], _FuncType[_T]]:
@overload
def singleton(data_key: HassKey[_T]) -> Callable[[_FuncType[_T]], _FuncType[_T]]: ...


@overload
def singleton(data_key: str) -> Callable[[_FuncType[_T]], _FuncType[_T]]: ...


def singleton(data_key: Any) -> Callable[[_FuncType[_T]], _FuncType[_T]]:
"""Decorate a function that should be called once per instance.
Result will be cached and simultaneous calls will be handled.
Expand Down
42 changes: 19 additions & 23 deletions homeassistant/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .helpers.issue_registry import IssueSeverity, async_create_issue
from .helpers.typing import ConfigType
from .util.async_ import create_eager_task
from .util.hass_dict import HassKey

current_setup_group: contextvars.ContextVar[tuple[str, str | None] | None] = (
contextvars.ContextVar("current_setup_group", default=None)
Expand All @@ -45,29 +46,32 @@

BASE_PLATFORMS = {platform.value for platform in Platform}

# DATA_SETUP is a dict[str, asyncio.Future[bool]], indicating domains which are currently
# DATA_SETUP is a dict, indicating domains which are currently
# being setup or which failed to setup:
# - Tasks are added to DATA_SETUP by `async_setup_component`, the key is the domain
# being setup and the Task is the `_async_setup_component` helper.
# - Tasks are removed from DATA_SETUP if setup was successful, that is,
# the task returned True.
DATA_SETUP = "setup_tasks"
DATA_SETUP: HassKey[dict[str, asyncio.Future[bool]]] = HassKey("setup_tasks")

# DATA_SETUP_DONE is a dict [str, asyncio.Future[bool]], indicating components which
# will be setup:
# DATA_SETUP_DONE is a dict, indicating components which will be setup:
# - Events are added to DATA_SETUP_DONE during bootstrap by
# async_set_domains_to_be_loaded, the key is the domain which will be loaded.
# - Events are set and removed from DATA_SETUP_DONE when async_setup_component
# is finished, regardless of if the setup was successful or not.
DATA_SETUP_DONE = "setup_done"
DATA_SETUP_DONE: HassKey[dict[str, asyncio.Future[bool]]] = HassKey("setup_done")

# DATA_SETUP_STARTED is a dict [tuple[str, str | None], float], indicating when an attempt
# DATA_SETUP_STARTED is a dict, indicating when an attempt
# to setup a component started.
DATA_SETUP_STARTED = "setup_started"
DATA_SETUP_STARTED: HassKey[dict[tuple[str, str | None], float]] = HassKey(
"setup_started"
)

# DATA_SETUP_TIME is a defaultdict[str, defaultdict[str | None, defaultdict[SetupPhases, float]]]
# indicating how time was spent setting up a component and each group (config entry).
DATA_SETUP_TIME = "setup_time"
# DATA_SETUP_TIME is a defaultdict, indicating how time was spent
# setting up a component.
DATA_SETUP_TIME: HassKey[
defaultdict[str, defaultdict[str | None, defaultdict[SetupPhases, float]]]
] = HassKey("setup_time")

DATA_DEPS_REQS = "deps_reqs_processed"

Expand Down Expand Up @@ -126,9 +130,7 @@ def async_set_domains_to_be_loaded(hass: core.HomeAssistant, domains: set[str])
- Properly handle after_dependencies.
- Keep track of domains which will load but have not yet finished loading
"""
setup_done_futures: dict[str, asyncio.Future[bool]] = hass.data.setdefault(
DATA_SETUP_DONE, {}
)
setup_done_futures = hass.data.setdefault(DATA_SETUP_DONE, {})
setup_done_futures.update({domain: hass.loop.create_future() for domain in domains})


Expand All @@ -149,12 +151,8 @@ async def async_setup_component(
if domain in hass.config.components:
return True

setup_futures: dict[str, asyncio.Future[bool]] = hass.data.setdefault(
DATA_SETUP, {}
)
setup_done_futures: dict[str, asyncio.Future[bool]] = hass.data.setdefault(
DATA_SETUP_DONE, {}
)
setup_futures = hass.data.setdefault(DATA_SETUP, {})
setup_done_futures = hass.data.setdefault(DATA_SETUP_DONE, {})

if existing_setup_future := setup_futures.get(domain):
return await existing_setup_future
Expand Down Expand Up @@ -195,9 +193,7 @@ async def _async_process_dependencies(
Returns a list of dependencies which failed to set up.
"""
setup_futures: dict[str, asyncio.Future[bool]] = hass.data.setdefault(
DATA_SETUP, {}
)
setup_futures = hass.data.setdefault(DATA_SETUP, {})

dependencies_tasks = {
dep: setup_futures.get(dep)
Expand All @@ -210,7 +206,7 @@ async def _async_process_dependencies(
}

after_dependencies_tasks: dict[str, asyncio.Future[bool]] = {}
to_be_loaded: dict[str, asyncio.Future[bool]] = hass.data.get(DATA_SETUP_DONE, {})
to_be_loaded = hass.data.get(DATA_SETUP_DONE, {})
for dep in integration.after_dependencies:
if (
dep not in dependencies_tasks
Expand Down
31 changes: 31 additions & 0 deletions homeassistant/util/hass_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Implementation for HassDict and custom HassKey types.
Custom for type checking. See stub file.
"""

from __future__ import annotations

from typing import Generic, TypeVar

_T = TypeVar("_T")


class HassKey(str, Generic[_T]):
"""Generic Hass key type.
At runtime this is a generic subclass of str.
"""

__slots__ = ()


class HassEntryKey(str, Generic[_T]):
"""Key type for integrations with config entries.
At runtime this is a generic subclass of str.
"""

__slots__ = ()


HassDict = dict
176 changes: 176 additions & 0 deletions homeassistant/util/hass_dict.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""Stub file for hass_dict. Provide overload for type checking."""
# ruff: noqa: PYI021 # Allow docstrings

from typing import Any, Generic, TypeVar, assert_type, overload

__all__ = [
"HassDict",
"HassEntryKey",
"HassKey",
]

_T = TypeVar("_T")
_U = TypeVar("_U")

class _Key(Generic[_T]):
"""Base class for Hass key types. At runtime delegated to str."""

def __init__(self, value: str, /) -> None: ...
def __len__(self) -> int: ...
def __hash__(self) -> int: ...
def __eq__(self, other: object) -> bool: ...
def __getitem__(self, index: int) -> str: ...

class HassEntryKey(_Key[_T]):
"""Key type for integrations with config entries."""

class HassKey(_Key[_T]):
"""Generic Hass key type."""

class HassDict(dict[_Key[Any] | str, Any]):
"""Custom dict type to provide better value type hints for Hass key types."""

@overload # type: ignore[override]
def __getitem__(self, key: HassEntryKey[_T], /) -> dict[str, _T]: ...
@overload
def __getitem__(self, key: HassKey[_T], /) -> _T: ...
@overload
def __getitem__(self, key: str, /) -> Any: ...

# ------
@overload # type: ignore[override]
def __setitem__(self, key: HassEntryKey[_T], value: dict[str, _T], /) -> None: ...
@overload
def __setitem__(self, key: HassKey[_T], value: _T, /) -> None: ...
@overload
def __setitem__(self, key: str, value: Any, /) -> None: ...

# ------
@overload # type: ignore[override]
def setdefault(
self, key: HassEntryKey[_T], default: dict[str, _T], /
) -> dict[str, _T]: ...
@overload
def setdefault(self, key: HassKey[_T], default: _T, /) -> _T: ...
@overload
def setdefault(self, key: str, default: None = None, /) -> Any | None: ...
@overload
def setdefault(self, key: str, default: Any, /) -> Any: ...

# ------
@overload # type: ignore[override]
def get(self, key: HassEntryKey[_T], /) -> dict[str, _T] | None: ...
@overload
def get(self, key: HassEntryKey[_T], default: _U, /) -> dict[str, _T] | _U: ...
@overload
def get(self, key: HassKey[_T], /) -> _T | None: ...
@overload
def get(self, key: HassKey[_T], default: _U, /) -> _T | _U: ...
@overload
def get(self, key: str, /) -> Any | None: ...
@overload
def get(self, key: str, default: Any, /) -> Any: ...

# ------
@overload # type: ignore[override]
def pop(self, key: HassEntryKey[_T], /) -> dict[str, _T]: ...
@overload
def pop(
self, key: HassEntryKey[_T], default: dict[str, _T], /
) -> dict[str, _T]: ...
@overload
def pop(self, key: HassEntryKey[_T], default: _U, /) -> dict[str, _T] | _U: ...
@overload
def pop(self, key: HassKey[_T], /) -> _T: ...
@overload
def pop(self, key: HassKey[_T], default: _T, /) -> _T: ...
@overload
def pop(self, key: HassKey[_T], default: _U, /) -> _T | _U: ...
@overload
def pop(self, key: str, /) -> Any: ...
@overload
def pop(self, key: str, default: _U, /) -> Any | _U: ...

def _test_hass_dict_typing() -> None: # noqa: PYI048
"""Test HassDict overloads work as intended.
This is tested during the mypy run. Do not move it to 'tests'!
"""
d = HassDict()
entry_key = HassEntryKey[int]("entry_key")
key = HassKey[int]("key")
key2 = HassKey[dict[int, bool]]("key2")
key3 = HassKey[set[str]]("key3")
other_key = "domain"

# __getitem__
assert_type(d[entry_key], dict[str, int])
assert_type(d[entry_key]["entry_id"], int)
assert_type(d[key], int)
assert_type(d[key2], dict[int, bool])

# __setitem__
d[entry_key] = {}
d[entry_key] = 2 # type: ignore[call-overload]
d[entry_key]["entry_id"] = 2
d[entry_key]["entry_id"] = "Hello World" # type: ignore[assignment]
d[key] = 2
d[key] = "Hello World" # type: ignore[misc]
d[key] = {} # type: ignore[misc]
d[key2] = {}
d[key2] = 2 # type: ignore[misc]
d[key3] = set()
d[key3] = 2 # type: ignore[misc]
d[other_key] = 2
d[other_key] = "Hello World"

# get
assert_type(d.get(entry_key), dict[str, int] | None)
assert_type(d.get(entry_key, True), dict[str, int] | bool)
assert_type(d.get(key), int | None)
assert_type(d.get(key, True), int | bool)
assert_type(d.get(key2), dict[int, bool] | None)
assert_type(d.get(key2, {}), dict[int, bool])
assert_type(d.get(key3), set[str] | None)
assert_type(d.get(key3, set()), set[str])
assert_type(d.get(other_key), Any | None)
assert_type(d.get(other_key, True), Any)
assert_type(d.get(other_key, {})["id"], Any)

# setdefault
assert_type(d.setdefault(entry_key, {}), dict[str, int])
assert_type(d.setdefault(entry_key, {})["entry_id"], int)
assert_type(d.setdefault(key, 2), int)
assert_type(d.setdefault(key2, {}), dict[int, bool])
assert_type(d.setdefault(key2, {})[2], bool)
assert_type(d.setdefault(key3, set()), set[str])
assert_type(d.setdefault(other_key, 2), Any)
assert_type(d.setdefault(other_key), Any | None)
d.setdefault(entry_key, {})["entry_id"] = 2
d.setdefault(entry_key, {})["entry_id"] = "Hello World" # type: ignore[assignment]
d.setdefault(key, 2)
d.setdefault(key, "Error") # type: ignore[misc]
d.setdefault(key2, {})[2] = True
d.setdefault(key2, {})[2] = "Error" # type: ignore[assignment]
d.setdefault(key3, set()).add("Hello World")
d.setdefault(key3, set()).add(2) # type: ignore[arg-type]
d.setdefault(other_key, {})["id"] = 2
d.setdefault(other_key, {})["id"] = "Hello World"
d.setdefault(entry_key) # type: ignore[call-overload]
d.setdefault(key) # type: ignore[call-overload]
d.setdefault(key2) # type: ignore[call-overload]

# pop
assert_type(d.pop(entry_key), dict[str, int])
assert_type(d.pop(entry_key, {}), dict[str, int])
assert_type(d.pop(entry_key, 2), dict[str, int] | int)
assert_type(d.pop(key), int)
assert_type(d.pop(key, 2), int)
assert_type(d.pop(key, "Hello World"), int | str)
assert_type(d.pop(key2), dict[int, bool])
assert_type(d.pop(key2, {}), dict[int, bool])
assert_type(d.pop(key2, 2), dict[int, bool] | int)
assert_type(d.pop(key3), set[str])
assert_type(d.pop(key3, set()), set[str])
assert_type(d.pop(other_key), Any)
assert_type(d.pop(other_key, True), Any | bool)
Loading

0 comments on commit 3d700e2

Please sign in to comment.