diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 26c5146d38..acef1647a0 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1065,8 +1065,13 @@ class TaskState: #: Cached hash of :attr:`~TaskState.client_key` _hash: int + # Support for weakrefs to a class with __slots__ + __weakref__: Any = None __slots__ = tuple(__annotations__) # type: ignore + # Instances not part of slots since class variable + _instances: ClassVar[weakref.WeakSet[TaskState]] = weakref.WeakSet() + def __init__(self, key: str, run_spec: object): self.key = key self._hash = hash(key) @@ -1101,6 +1106,7 @@ def __init__(self, key: str, run_spec: object): self.metadata = {} self.annotations = {} self.erred_on = set() + TaskState._instances.add(self) def __hash__(self) -> int: return self._hash diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 6172083408..a5076e7fe3 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -1,17 +1,22 @@ from __future__ import annotations import asyncio +import gc from collections.abc import Iterator import pytest +from tlz import first -from distributed import Worker, wait +import distributed.profile as profile +from distributed import Nanny, Worker, wait from distributed.protocol.serialize import Serialize +from distributed.scheduler import TaskState as SchedulerTaskState from distributed.utils import recursive_to_dict from distributed.utils_test import ( BlockedGetData, _LockedCommPool, assert_story, + clean, freeze_data_fetching, gen_cluster, inc, @@ -42,6 +47,17 @@ async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) - await asyncio.sleep(0.005) +@clean() +def test_task_state_tracking(): + with clean(): + x = TaskState("x") + assert len(TaskState._instances) == 1 + assert first(TaskState._instances) == x + + del x + assert len(TaskState._instances) == 0 + + def test_TaskState_get_nbytes(): assert TaskState("x", nbytes=123).get_nbytes() == 123 # Default to distributed.scheduler.default-data-size @@ -675,6 +691,34 @@ async def test_missing_to_waiting(c, s, w1, w2, w3): await f1 +@gen_cluster(client=True, Worker=Nanny) +async def test_task_state_instance_are_garbage_collected(c, s, a, b): + futs = c.map(inc, range(10)) + red = c.submit(sum, futs) + f1 = c.submit(inc, red, pure=False) + f2 = c.submit(inc, red, pure=False) + + async def check(dask_worker): + while dask_worker.tasks: + await asyncio.sleep(0.01) + with profile.lock: + gc.collect() + assert not TaskState._instances + + await c.gather([f2, f1]) + del futs, red, f1, f2 + await c.run(check) + + async def check(dask_scheduler): + while dask_scheduler.tasks: + await asyncio.sleep(0.01) + with profile.lock: + gc.collect() + assert not SchedulerTaskState._instances + + await c.run_on_scheduler(check) + + @gen_cluster(client=True, nthreads=[("", 1)] * 3) async def test_fetch_to_missing_on_refresh_who_has(c, s, w1, w2, w3): """ diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 42e90de0eb..4fa11486c2 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -58,6 +58,7 @@ from distributed.node import ServerNode from distributed.proctitle import enable_proctitle_on_children from distributed.protocol import deserialize +from distributed.scheduler import TaskState as SchedulerTaskState from distributed.security import Security from distributed.utils import ( DequeHandler, @@ -72,6 +73,7 @@ ) from distributed.worker import WORKER_ANY_RUNNING, Worker from distributed.worker_state_machine import InvalidTransition +from distributed.worker_state_machine import TaskState as WorkerTaskState try: import ssl @@ -1813,9 +1815,8 @@ def check_instances(): Scheduler._instances.clear() SpecCluster._instances.clear() Worker._initialized_clients.clear() - # assert all(n.status == "closed" for n in Nanny._instances), { - # n: n.status for n in Nanny._instances - # } + SchedulerTaskState._instances.clear() + WorkerTaskState._instances.clear() Nanny._instances.clear() _global_clients.clear() Comm._instances.clear() diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 3cca984cc7..c4ee6d4c1f 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -8,6 +8,7 @@ import operator import random import sys +import weakref from collections import defaultdict, deque from collections.abc import ( Callable, @@ -262,20 +263,23 @@ class TaskState: #: True if the task is in memory or erred; False otherwise done: bool = False + _instances: ClassVar[weakref.WeakSet[TaskState]] = weakref.WeakSet() + # Support for weakrefs to a class with __slots__ __weakref__: Any = field(init=False) + def __post_init__(self): + TaskState._instances.add(self) + def __repr__(self) -> str: return f"" def __eq__(self, other: object) -> bool: - if not isinstance(other, TaskState) or other.key != self.key: - return False - # When a task transitions to forgotten and exits Worker.tasks, it should be - # immediately dereferenced. If the same task is recreated later on on the - # worker, we should not have to deal with its previous incarnation lingering. - assert other is self - return True + # A task may be forgotten and a new TaskState object with the same key may be created in + # its place later on. In the Worker state, you should never have multiple TaskState objects with + # the same key. We can't assert it here however, as this comparison is also used in WeakSets + # for instance tracking purposes. + return other is self def __hash__(self) -> int: return hash(self.key) @@ -2998,6 +3002,11 @@ def validate_state(self) -> None: if self.transition_counter_max: assert self.transition_counter < self.transition_counter_max + # Test that there aren't multiple TaskState objects with the same key in data_needed + assert len({ts.key for ts in self.data_needed}) == len(self.data_needed) + for tss in self.data_needed_per_worker.values(): + assert len({ts.key for ts in tss}) == len(tss) + class BaseWorker(abc.ABC): """Wrapper around the :class:`WorkerState` that implements instructions handling.