Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep task references while running #87970

Merged
merged 15 commits into from
Feb 14, 2023
2 changes: 1 addition & 1 deletion homeassistant/components/graphite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool:
else:
_LOGGER.debug("No connection check for UDP possible")

GraphiteFeeder(hass, host, port, protocol, prefix)
hass.data[DOMAIN] = GraphiteFeeder(hass, host, port, protocol, prefix)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this changed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously the test relied on race conditions about when the thread was done with work. Now the test just wait for the thread to be done.

return True


Expand Down
86 changes: 34 additions & 52 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,7 @@ def __new__(cls) -> HomeAssistant:
def __init__(self) -> None:
"""Initialize new Home Assistant object."""
self.loop = asyncio.get_running_loop()
self._pending_tasks: list[asyncio.Future[Any]] = []
self._track_task = True
self._tasks: set[asyncio.Future[Any]] = set()
self.bus = EventBus(self)
self.services = ServiceRegistry(self)
self.states = StateMachine(self.bus, self.loop)
Expand Down Expand Up @@ -353,12 +352,14 @@ async def async_start(self) -> None:
self.bus.async_fire(EVENT_CORE_CONFIG_UPDATE)
self.bus.async_fire(EVENT_HOMEASSISTANT_START)

try:
# Only block for EVENT_HOMEASSISTANT_START listener
self.async_stop_track_tasks()
async with self.timeout.async_timeout(TIMEOUT_EVENT_START):
await self.async_block_till_done()
except asyncio.TimeoutError:
if not self._tasks:
pending: set[asyncio.Future[Any]] | None = None
else:
_done, pending = await asyncio.wait(
self._tasks, timeout=TIMEOUT_EVENT_START
)

if pending:
_LOGGER.warning(
(
"Something is blocking Home Assistant from wrapping up the start up"
Expand Down Expand Up @@ -494,9 +495,8 @@ def async_add_hass_job(
hassjob.target = cast(Callable[..., _R], hassjob.target)
task = self.loop.run_in_executor(None, hassjob.target, *args)

# If a task is scheduled
if self._track_task:
self._pending_tasks.append(task)
self._tasks.add(task)
task.add_done_callback(self._tasks.remove)

return task

Expand All @@ -516,9 +516,8 @@ def async_create_task(self, target: Coroutine[Any, Any, _R]) -> asyncio.Task[_R]
target: target to call.
"""
task = self.loop.create_task(target)

if self._track_task:
self._pending_tasks.append(task)
self._tasks.add(task)
task.add_done_callback(self._tasks.remove)

return task

Expand All @@ -528,23 +527,11 @@ def async_add_executor_job(
) -> asyncio.Future[_T]:
"""Add an executor job from within the event loop."""
task = self.loop.run_in_executor(None, target, *args)

# If a task is scheduled
if self._track_task:
self._pending_tasks.append(task)
self._tasks.add(task)
task.add_done_callback(self._tasks.remove)

return task

@callback
def async_track_tasks(self) -> None:
"""Track tasks so you can wait for all tasks to be done."""
self._track_task = True

@callback
def async_stop_track_tasks(self) -> None:
"""Stop track tasks so you can't wait for all tasks to be done."""
self._track_task = False

@overload
@callback
def async_run_hass_job(
Expand Down Expand Up @@ -638,29 +625,25 @@ async def async_block_till_done(self) -> None:
# To flush out any call_soon_threadsafe
await asyncio.sleep(0)
start_time: float | None = None

while self._pending_tasks:
pending = [task for task in self._pending_tasks if not task.done()]
self._pending_tasks.clear()
Comment on lines -642 to -644
Copy link
Member Author

@balloob balloob Feb 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old approach would clear all the tasks so it would only track new tasks. The new approach is to just track all tasks. Old + new. It doesn't matter because if the old one blocks, it's also blocking. It was also necessary because each task will try to remove itself from this set once it's done.

if pending:
await self._await_and_log_pending(pending)

if start_time is None:
# Avoid calling monotonic() until we know
# we may need to start logging blocked tasks.
start_time = 0
elif start_time == 0:
# If we have waited twice then we set the start
# time
start_time = monotonic()
elif monotonic() - start_time > BLOCK_LOG_TIMEOUT:
# We have waited at least three loops and new tasks
# continue to block. At this point we start
# logging all waiting tasks.
for task in pending:
_LOGGER.debug("Waiting for task: %s", task)
else:
await asyncio.sleep(0)
current_task = asyncio.current_task()

while tasks := [task for task in self._tasks if task is not current_task]:
await self._await_and_log_pending(tasks)

if start_time is None:
# Avoid calling monotonic() until we know
# we may need to start logging blocked tasks.
start_time = 0
elif start_time == 0:
# If we have waited twice then we set the start
# time
start_time = monotonic()
elif monotonic() - start_time > BLOCK_LOG_TIMEOUT:
# We have waited at least three loops and new tasks
# continue to block. At this point we start
# logging all waiting tasks.
for task in tasks:
_LOGGER.debug("Waiting for task: %s", task)

async def _await_and_log_pending(self, pending: Collection[Awaitable[Any]]) -> None:
"""Await and log tasks that take a long time."""
Expand Down Expand Up @@ -704,7 +687,6 @@ async def async_stop(self, exit_code: int = 0, *, force: bool = False) -> None:

# stage 1
self.state = CoreState.stopping
self.async_track_tasks()
self.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
try:
async with self.timeout.async_timeout(STAGE_1_SHUTDOWN_TIMEOUT):
Expand Down
90 changes: 1 addition & 89 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,7 @@

import asyncio
from collections import OrderedDict
from collections.abc import (
Awaitable,
Callable,
Collection,
Generator,
Mapping,
Sequence,
)
from collections.abc import Awaitable, Callable, Generator, Mapping, Sequence
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
import functools as ft
Expand All @@ -22,8 +15,6 @@
import pathlib
import threading
import time
from time import monotonic
import types
from typing import Any, NoReturn
from unittest.mock import AsyncMock, Mock, patch

Expand Down Expand Up @@ -51,7 +42,6 @@
STATE_ON,
)
from homeassistant.core import (
BLOCK_LOG_TIMEOUT,
CoreState,
Event,
HomeAssistant,
Expand Down Expand Up @@ -221,76 +211,9 @@ def async_create_task(coroutine):

return orig_async_create_task(coroutine)

async def async_wait_for_task_count(self, max_remaining_tasks: int = 0) -> None:
"""Block until at most max_remaining_tasks remain.

Based on HomeAssistant.async_block_till_done
"""
# To flush out any call_soon_threadsafe
await asyncio.sleep(0)
start_time: float | None = None

while len(self._pending_tasks) > max_remaining_tasks:
pending: Collection[Awaitable[Any]] = [
task for task in self._pending_tasks if not task.done()
]
self._pending_tasks.clear()
if len(pending) > max_remaining_tasks:
remaining_pending = await self._await_count_and_log_pending(
pending, max_remaining_tasks=max_remaining_tasks
)
self._pending_tasks.extend(remaining_pending)

if start_time is None:
# Avoid calling monotonic() until we know
# we may need to start logging blocked tasks.
start_time = 0
elif start_time == 0:
# If we have waited twice then we set the start
# time
start_time = monotonic()
elif monotonic() - start_time > BLOCK_LOG_TIMEOUT:
# We have waited at least three loops and new tasks
# continue to block. At this point we start
# logging all waiting tasks.
for task in pending:
_LOGGER.debug("Waiting for task: %s", task)
else:
self._pending_tasks.extend(pending)
await asyncio.sleep(0)

async def _await_count_and_log_pending(
self, pending: Collection[Awaitable[Any]], max_remaining_tasks: int = 0
) -> Collection[Awaitable[Any]]:
"""Block at most max_remaining_tasks remain and log tasks that take a long time.

Based on HomeAssistant._await_and_log_pending
"""
wait_time = 0

return_when = asyncio.ALL_COMPLETED
if max_remaining_tasks:
return_when = asyncio.FIRST_COMPLETED

while len(pending) > max_remaining_tasks:
_, pending = await asyncio.wait(
pending, timeout=BLOCK_LOG_TIMEOUT, return_when=return_when
)
if not pending or max_remaining_tasks:
return pending
wait_time += BLOCK_LOG_TIMEOUT
for task in pending:
_LOGGER.debug("Waited %s seconds for task: %s", wait_time, task)

return []

hass.async_add_job = async_add_job
hass.async_add_executor_job = async_add_executor_job
hass.async_create_task = async_create_task
hass.async_wait_for_task_count = types.MethodType(async_wait_for_task_count, hass)
hass._await_count_and_log_pending = types.MethodType(
_await_count_and_log_pending, hass
)

hass.data[loader.DATA_CUSTOM_COMPONENTS] = {}

Expand Down Expand Up @@ -328,17 +251,6 @@ async def _await_count_and_log_pending(

hass.state = CoreState.running

# Mock async_start
orig_start = hass.async_start

async def mock_async_start():
"""Start the mocking."""
# We only mock time during tests and we want to track tasks
with patch.object(hass, "async_stop_track_tasks"):
await orig_start()

hass.async_start = mock_async_start

@callback
def clear_instance(event):
"""Clear global instance."""
Expand Down
1 change: 1 addition & 0 deletions tests/components/device_sun_light_trigger/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,5 +239,6 @@ async def test_initialize_start(hass: HomeAssistant) -> None:
) as mock_activate:
hass.bus.fire(EVENT_HOMEASSISTANT_START)
await hass.async_block_till_done()
await hass.async_block_till_done()

assert len(mock_activate.mock_calls) == 1
33 changes: 23 additions & 10 deletions tests/components/graphite/test_init.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""The tests for the Graphite component."""
import asyncio
import socket
from unittest import mock
from unittest.mock import patch
Expand Down Expand Up @@ -91,9 +90,11 @@ async def test_start(hass: HomeAssistant, mock_socket, mock_time) -> None:
mock_socket.reset_mock()

await hass.async_start()
await hass.async_block_till_done()

hass.states.async_set("test.entity", STATE_ON)
await asyncio.sleep(0.1)
await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()

assert mock_socket.return_value.connect.call_count == 1
assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003))
Expand All @@ -114,9 +115,11 @@ async def test_shutdown(hass: HomeAssistant, mock_socket, mock_time) -> None:
mock_socket.reset_mock()

await hass.async_start()
await hass.async_block_till_done()

hass.states.async_set("test.entity", STATE_ON)
await asyncio.sleep(0.1)
await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()

assert mock_socket.return_value.connect.call_count == 1
assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003))
Expand All @@ -134,7 +137,7 @@ async def test_shutdown(hass: HomeAssistant, mock_socket, mock_time) -> None:
await hass.async_block_till_done()

hass.states.async_set("test.entity", STATE_OFF)
await asyncio.sleep(0.1)
await hass.async_block_till_done()

assert mock_socket.return_value.connect.call_count == 0
assert mock_socket.return_value.sendall.call_count == 0
Expand All @@ -156,9 +159,11 @@ async def test_report_attributes(hass: HomeAssistant, mock_socket, mock_time) ->
mock_socket.reset_mock()

await hass.async_start()
await hass.async_block_till_done()

hass.states.async_set("test.entity", STATE_ON, attrs)
await asyncio.sleep(0.1)
await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()

assert mock_socket.return_value.connect.call_count == 1
assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003))
Expand Down Expand Up @@ -186,9 +191,11 @@ async def test_report_with_string_state(
mock_socket.reset_mock()

await hass.async_start()
await hass.async_block_till_done()

hass.states.async_set("test.entity", "above_horizon", {"foo": 1.0})
await asyncio.sleep(0.1)
await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()

assert mock_socket.return_value.connect.call_count == 1
assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003))
Expand All @@ -203,7 +210,8 @@ async def test_report_with_string_state(
mock_socket.reset_mock()

hass.states.async_set("test.entity", "not_float")
await asyncio.sleep(0.1)
await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()

assert mock_socket.return_value.connect.call_count == 0
assert mock_socket.return_value.sendall.call_count == 0
Expand All @@ -221,13 +229,15 @@ async def test_report_with_binary_state(
mock_socket.reset_mock()

await hass.async_start()
await hass.async_block_till_done()

expected = [
"ha.test.entity.foo 1.000000 12345",
"ha.test.entity.state 1.000000 12345",
]
hass.states.async_set("test.entity", STATE_ON, {"foo": 1.0})
await asyncio.sleep(0.1)
await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()

assert mock_socket.return_value.connect.call_count == 1
assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003))
Expand All @@ -246,7 +256,8 @@ async def test_report_with_binary_state(
"ha.test.entity.state 0.000000 12345",
]
hass.states.async_set("test.entity", STATE_OFF, {"foo": 1.0})
await asyncio.sleep(0.1)
await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()

assert mock_socket.return_value.connect.call_count == 1
assert mock_socket.return_value.connect.call_args == mock.call(("localhost", 2003))
Expand Down Expand Up @@ -282,10 +293,12 @@ async def test_send_to_graphite_errors(
mock_socket.reset_mock()

await hass.async_start()
await hass.async_block_till_done()

mock_socket.return_value.connect.side_effect = error

hass.states.async_set("test.entity", STATE_ON)
await asyncio.sleep(0.1)
await hass.async_block_till_done()
hass.data[graphite.DOMAIN]._queue.join()

assert log_text in caplog.text
Loading