Skip to content

Commit

Permalink
Keep task references while running
Browse files Browse the repository at this point in the history
  • Loading branch information
balloob committed Feb 13, 2023
1 parent 9008556 commit de389a4
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 135 deletions.
4 changes: 3 additions & 1 deletion homeassistant/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,9 @@ async def _async_set_up_integrations(
hass.data[DATA_SETUP_STARTED] = {}
setup_time: dict[str, timedelta] = hass.data.setdefault(DATA_SETUP_TIME, {})

watch_task = asyncio.create_task(_async_watch_pending_setups(hass))
watch_task = hass.background_tasks.async_create_task(
_async_watch_pending_setups(hass), "watch_pending_setups"
)

domains_to_setup = _get_domains(hass, config)

Expand Down
137 changes: 88 additions & 49 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,49 @@ def __str__(self) -> str:
return self.value


class BackgroundTasks:
"""Class to manage background tasks."""

def __init__(self) -> None:
"""Initialize the background task runner."""
self._tasks: set[asyncio.Task[Any]] = set()
self._loop = asyncio.get_running_loop()
self._running = True

def async_create_task(
self,
target: Coroutine[Any, Any, _R],
name: str,
) -> asyncio.Task[_R]:
"""Create a task and add it to the set of tasks."""
if not self._running:
raise RuntimeError("BackgroundTasks is no longer running")
task = self._loop.create_task(target, name=name)
self._tasks.add(task)
task.add_done_callback(self._tasks.remove)
return task

async def async_cancel_all(self) -> None:
"""Cancel all tasks."""
self._running = False

if not self._tasks:
return

for task in self._tasks:
task.cancel()

for task in list(self._tasks):
try:
await task
except asyncio.CancelledError:
pass
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error cancelling task %s", task)

self._tasks.clear()


class HomeAssistant:
"""Root object of the Home Assistant home automation."""

Expand All @@ -276,8 +319,8 @@ def __new__(cls) -> HomeAssistant:
def __init__(self) -> None:
"""Initialize new Home Assistant object."""
self.loop = asyncio.get_running_loop()
self._pending_tasks: list[asyncio.Future[Any]] = []
self._track_task = True
self._tasks: set[asyncio.Future[Any]] = set()
self.background_tasks = BackgroundTasks()
self.bus = EventBus(self)
self.services = ServiceRegistry(self)
self.states = StateMachine(self.bus, self.loop)
Expand Down Expand Up @@ -353,12 +396,14 @@ async def async_start(self) -> None:
self.bus.async_fire(EVENT_CORE_CONFIG_UPDATE)
self.bus.async_fire(EVENT_HOMEASSISTANT_START)

try:
# Only block for EVENT_HOMEASSISTANT_START listener
self.async_stop_track_tasks()
async with self.timeout.async_timeout(TIMEOUT_EVENT_START):
await self.async_block_till_done()
except asyncio.TimeoutError:
if not self._tasks:
pending: set[asyncio.Future[Any]] | None = None
else:
_done, pending = await asyncio.wait(
self._tasks, timeout=TIMEOUT_EVENT_START
)

if pending:
_LOGGER.warning(
(
"Something is blocking Home Assistant from wrapping up the start up"
Expand Down Expand Up @@ -494,9 +539,8 @@ def async_add_hass_job(
hassjob.target = cast(Callable[..., _R], hassjob.target)
task = self.loop.run_in_executor(None, hassjob.target, *args)

# If a task is scheduled
if self._track_task:
self._pending_tasks.append(task)
self._tasks.add(task)
task.add_done_callback(self._tasks.remove)

return task

Expand All @@ -517,8 +561,8 @@ def async_create_task(self, target: Coroutine[Any, Any, _R]) -> asyncio.Task[_R]
"""
task = self.loop.create_task(target)

if self._track_task:
self._pending_tasks.append(task)
self._tasks.add(task)
task.add_done_callback(self._tasks.remove)

return task

Expand All @@ -530,21 +574,11 @@ def async_add_executor_job(
task = self.loop.run_in_executor(None, target, *args)

# If a task is scheduled
if self._track_task:
self._pending_tasks.append(task)
self._tasks.add(task)
task.add_done_callback(self._tasks.remove)

return task

@callback
def async_track_tasks(self) -> None:
"""Track tasks so you can wait for all tasks to be done."""
self._track_task = True

@callback
def async_stop_track_tasks(self) -> None:
"""Stop track tasks so you can't wait for all tasks to be done."""
self._track_task = False

@overload
@callback
def async_run_hass_job(
Expand Down Expand Up @@ -637,30 +671,27 @@ async def async_block_till_done(self) -> None:
"""Block until all pending work is done."""
# To flush out any call_soon_threadsafe
await asyncio.sleep(0)
await asyncio.sleep(0)
start_time: float | None = None

while self._pending_tasks:
pending = [task for task in self._pending_tasks if not task.done()]
self._pending_tasks.clear()
if pending:
await self._await_and_log_pending(pending)

if start_time is None:
# Avoid calling monotonic() until we know
# we may need to start logging blocked tasks.
start_time = 0
elif start_time == 0:
# If we have waited twice then we set the start
# time
start_time = monotonic()
elif monotonic() - start_time > BLOCK_LOG_TIMEOUT:
# We have waited at least three loops and new tasks
# continue to block. At this point we start
# logging all waiting tasks.
for task in pending:
_LOGGER.debug("Waiting for task: %s", task)
else:
await asyncio.sleep(0)
current_task = asyncio.current_task()

while tasks := [task for task in self._tasks if task is not current_task]:
await self._await_and_log_pending(tasks)

if start_time is None:
# Avoid calling monotonic() until we know
# we may need to start logging blocked tasks.
start_time = 0
elif start_time == 0:
# If we have waited twice then we set the start
# time
start_time = monotonic()
elif monotonic() - start_time > BLOCK_LOG_TIMEOUT:
# We have waited at least three loops and new tasks
# continue to block. At this point we start
# logging all waiting tasks.
for task in tasks:
_LOGGER.debug("Waiting for task: %s", task)

async def _await_and_log_pending(self, pending: Collection[Awaitable[Any]]) -> None:
"""Await and log tasks that take a long time."""
Expand Down Expand Up @@ -702,9 +733,12 @@ async def async_stop(self, exit_code: int = 0, *, force: bool = False) -> None:
"Stopping Home Assistant before startup has completed may fail"
)

cancel_background_tasks = asyncio.create_task(
self.background_tasks.async_cancel_all()
)

# stage 1
self.state = CoreState.stopping
self.async_track_tasks()
self.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
try:
async with self.timeout.async_timeout(STAGE_1_SHUTDOWN_TIMEOUT):
Expand Down Expand Up @@ -738,6 +772,11 @@ async def async_stop(self, exit_code: int = 0, *, force: bool = False) -> None:
# the `result()` which will cause a deadlock when shutting down the executor.
shutdown_run_callback_threadsafe(self.loop)

# Run this as part of stage 3.
if not cancel_background_tasks.done():
self._tasks.add(cancel_background_tasks)
cancel_background_tasks.add_done_callback(self._tasks.remove)

try:
async with self.timeout.async_timeout(STAGE_3_SHUTDOWN_TIMEOUT):
await self.async_block_till_done()
Expand Down
51 changes: 0 additions & 51 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import pathlib
import threading
import time
from time import monotonic
import types
from typing import Any, NoReturn
from unittest.mock import AsyncMock, Mock, patch
Expand Down Expand Up @@ -214,44 +213,6 @@ def async_create_task(coroutine):

return orig_async_create_task(coroutine)

async def async_wait_for_task_count(self, max_remaining_tasks: int = 0) -> None:
"""Block until at most max_remaining_tasks remain.
Based on HomeAssistant.async_block_till_done
"""
# To flush out any call_soon_threadsafe
await asyncio.sleep(0)
start_time: float | None = None

while len(self._pending_tasks) > max_remaining_tasks:
pending: Collection[Awaitable[Any]] = [
task for task in self._pending_tasks if not task.done()
]
self._pending_tasks.clear()
if len(pending) > max_remaining_tasks:
remaining_pending = await self._await_count_and_log_pending(
pending, max_remaining_tasks=max_remaining_tasks
)
self._pending_tasks.extend(remaining_pending)

if start_time is None:
# Avoid calling monotonic() until we know
# we may need to start logging blocked tasks.
start_time = 0
elif start_time == 0:
# If we have waited twice then we set the start
# time
start_time = monotonic()
elif monotonic() - start_time > BLOCK_LOG_TIMEOUT:
# We have waited at least three loops and new tasks
# continue to block. At this point we start
# logging all waiting tasks.
for task in pending:
_LOGGER.debug("Waiting for task: %s", task)
else:
self._pending_tasks.extend(pending)
await asyncio.sleep(0)

async def _await_count_and_log_pending(
self, pending: Collection[Awaitable[Any]], max_remaining_tasks: int = 0
) -> Collection[Awaitable[Any]]:
Expand Down Expand Up @@ -280,7 +241,6 @@ async def _await_count_and_log_pending(
hass.async_add_job = async_add_job
hass.async_add_executor_job = async_add_executor_job
hass.async_create_task = async_create_task
hass.async_wait_for_task_count = types.MethodType(async_wait_for_task_count, hass)
hass._await_count_and_log_pending = types.MethodType(
_await_count_and_log_pending, hass
)
Expand Down Expand Up @@ -321,17 +281,6 @@ async def _await_count_and_log_pending(

hass.state = CoreState.running

# Mock async_start
orig_start = hass.async_start

async def mock_async_start():
"""Start the mocking."""
# We only mock time during tests and we want to track tasks
with patch.object(hass, "async_stop_track_tasks"):
await orig_start()

hass.async_start = mock_async_start

@callback
def clear_instance(event):
"""Clear global instance."""
Expand Down
5 changes: 1 addition & 4 deletions tests/helpers/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ async def test_wait_for_trigger_variables(hass: HomeAssistant) -> None:
actions = [
{
"alias": "variables",
"variables": {"seconds": 5},
"variables": {"seconds": 0.01},
},
{
"alias": wait_alias,
Expand All @@ -839,9 +839,6 @@ async def test_wait_for_trigger_variables(hass: HomeAssistant) -> None:
assert script_obj.is_running
assert script_obj.last_action == wait_alias
hass.states.async_set("switch.test", "off")
# the script task + 2 tasks created by wait_for_trigger script step
await hass.async_wait_for_task_count(3)
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=10))
await hass.async_block_till_done()
except (AssertionError, asyncio.TimeoutError):
await script_obj.async_stop()
Expand Down
11 changes: 7 additions & 4 deletions tests/test_config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3559,14 +3559,17 @@ async def async_step_reauth(self, data):
"""Mock Reauth."""
await asyncio.sleep(1)

mock_integration(hass, MockModule("test"))
mock_entity_platform(hass, "config_flow.test", None)

with patch.dict(
config_entries.HANDLERS, {"comp": MockFlowHandler, "test": MockFlowHandler}
):
task = asyncio.create_task(
manager.flow.async_init("test", context={"source": "reauth"})
)
await hass.async_block_till_done()
await manager.flow.async_shutdown()
await hass.async_block_till_done()
await manager.flow.async_shutdown()

with pytest.raises(asyncio.exceptions.CancelledError):
await task
with pytest.raises(asyncio.exceptions.CancelledError):
await task
Loading

0 comments on commit de389a4

Please sign in to comment.