Track worker_state_machine.TaskState instances #6525

Merged · 14 commits · Jun 16, 2022
56 changes: 56 additions & 0 deletions distributed/tests/test_worker_state_machine.py
@@ -1,9 +1,12 @@
from __future__ import annotations

import asyncio
import subprocess
import sys
from collections.abc import Iterator

import pytest
from tlz import first

from distributed import Worker, wait
from distributed.protocol.serialize import Serialize
@@ -12,6 +15,7 @@
    BlockedGetData,
    _LockedCommPool,
    assert_story,
    clean,
    freeze_data_fetching,
    gen_cluster,
    inc,
@@ -40,6 +44,16 @@ async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -
        await asyncio.sleep(0.005)


@clean()
def test_task_state_tracking():
    x = TaskState("x")
    assert len(TaskState._instances) == 1
    assert first(TaskState._instances) == x

    del x
    assert len(TaskState._instances) == 0
Review comment (Member):
You should add the clean ctxmanager or just the clean_instances ctxmanager. I don't think this is applied automatically, and this test would otherwise be fragile and dependent on test order.
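
(For illustration only, not part of this PR: a minimal sketch of what such an instance-cleaning context manager can look like, assuming only the TaskState._instances WeakSet added in this PR. The name clean_taskstate_instances is hypothetical; the real clean / clean_instances helpers mentioned above do more than this.)

import contextlib

from distributed.worker_state_machine import TaskState


@contextlib.contextmanager
def clean_taskstate_instances():
    # Start from an empty WeakSet so earlier tests cannot leak into this one...
    TaskState._instances.clear()
    yield
    # ...and fail loudly if the wrapped test left TaskState objects behind.
    assert not TaskState._instances, list(TaskState._instances)

A test wrapped in this context manager starts with len(TaskState._instances) == 0 regardless of test order, which is the fragility the comment above points at.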



def test_TaskState_get_nbytes():
    assert TaskState("x", nbytes=123).get_nbytes() == 123
    # Default to distributed.scheduler.default-data-size
@@ -603,3 +617,45 @@ async def test_missing_to_waiting(c, s, w1, w2, w3):
    await w1.close()

    await f1


client_script = """
from dask.distributed import Client
from dask.distributed.worker_state_machine import TaskState


def inc(x):
    return x + 1


if __name__ == "__main__":
    with Client(processes=%s, n_workers=1) as client:
        futs = client.map(inc, range(10))
        red = client.submit(sum, futs)
        f1 = client.submit(inc, red, pure=False)
        f2 = client.submit(inc, red, pure=False)
        f2.result()
        del futs, red, f1, f2

        def check():
            assert not TaskState._instances, len(TaskState._instances)

        client.run(check)
"""


@pytest.mark.parametrize("processes", [True, False])
def test_task_state_instance_are_garbage_collected(processes, tmp_path):
    with open(tmp_path / "script.py", mode="w") as f:
        f.write(client_script % processes)

    proc = subprocess.Popen(
        [sys.executable, tmp_path / "script.py"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    out, err = proc.communicate()

    assert not out
    assert not err
5 changes: 2 additions & 3 deletions distributed/utils_test.py
@@ -70,6 +70,7 @@
    sync,
)
from distributed.worker import WORKER_ANY_RUNNING, InvalidTransition, Worker
from distributed.worker_state_machine import TaskState as WorkerTaskState

try:
    import ssl
@@ -1823,9 +1824,7 @@ def check_instances():
    Scheduler._instances.clear()
    SpecCluster._instances.clear()
    Worker._initialized_clients.clear()
-    # assert all(n.status == "closed" for n in Nanny._instances), {
-    #     n: n.status for n in Nanny._instances
-    # }
+    WorkerTaskState._instances.clear()
    Nanny._instances.clear()
    _global_clients.clear()
    Comm._instances.clear()
5 changes: 5 additions & 0 deletions distributed/worker.py
@@ -4288,6 +4288,11 @@ def validate_state(self):
            if self.transition_counter_max:
                assert self.transition_counter < self.transition_counter_max

            # Test that there aren't multiple TaskState objects with the same key in data_needed
            assert len({ts.key for ts in self.data_needed}) == len(self.data_needed)
            for tss in self.data_needed_per_worker.values():
                assert len({ts.key for ts in tss}) == len(tss)

        except Exception as e:
            logger.error("Validate state failed", exc_info=e)
            logger.exception(e)
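
(For illustration only, not part of the PR: because TaskState equality is identity-based while its hash is key-based, see the worker_state_machine.py diff below, a plain set() of TaskState objects does not collapse duplicates that share a key; comparing the set of keys against the container length does. The TS class is a hypothetical stand-in.)

class TS:
    # Hypothetical stand-in mirroring TaskState's key-based hash and identity-based equality
    def __init__(self, key):
        self.key = key

    def __hash__(self):
        return hash(self.key)

    def __eq__(self, other):
        return other is self


data_needed = [TS("x"), TS("x")]                 # same key, two distinct objects
assert len(set(data_needed)) == 2                # identity equality: duplicate key not collapsed
assert len({ts.key for ts in data_needed}) == 1  # the key-based comparison above would flag this
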
18 changes: 11 additions & 7 deletions distributed/worker_state_machine.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import sys
import weakref
from collections.abc import Collection, Container
from copy import copy
from dataclasses import dataclass, field
@@ -231,20 +232,23 @@ class TaskState:
    #: True if the task is in memory or erred; False otherwise
    done: bool = False

    _instances: ClassVar[weakref.WeakSet[TaskState]] = weakref.WeakSet()

    # Support for weakrefs to a class with __slots__
    __weakref__: Any = field(init=False)

    def __post_init__(self):
        TaskState._instances.add(self)
Review comment (Collaborator) on lines +271 to +272:
This doesn't work on unpickle.

Suggested change:
-    def __post_init__(self):
-        TaskState._instances.add(self)
+    def __new__(cls, *args, **kwargs):
+        TaskState._instances.add(self)
+        return object.__new__(cls)

Plus a unit test for the pickle/unpickle round trip.
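
(For illustration only, not part of the PR: a rough sketch of the kind of round-trip test being requested here. The test name is hypothetical. With the __post_init__-based registration that was merged, the unpickled copy is simply not re-added to TaskState._instances, as discussed below.)

import pickle

from distributed.worker_state_machine import TaskState


def test_taskstate_pickle_roundtrip():
    ts = TaskState("x", nbytes=123)
    ts2 = pickle.loads(pickle.dumps(ts))
    # Unpickling bypasses __init__/__post_init__, so only the field values are checked here
    assert ts2.key == ts.key
    assert ts2.nbytes == ts.nbytes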

Reply (Member):
We're not pickling TaskState objects anywhere, are we?

Reply (Collaborator):
We aren't today, but that's not an excuse not to future-proof it. Namely, it would make a lot of sense to pickle-dump our new WorkerState class and everything it contains.

Reply (Member):
Well, the "excuse" is scope creep.

Using __new__ as in

    def __new__(cls, *args, **kwargs):
        inst = object.__new__(cls)
        TaskState._instances.add(inst)
        return inst

does not work, because we define the hash function as the hash of the key, i.e. we can only add fully initialized TaskState objects to the WeakSet.

Apart from this, we actually can (un)pickle this class; the instance will simply not be added to the WeakSet. For the sole purpose of dumping state (as in our cluster dump), this is absolutely sufficient.
At this point, and for this functionality, I'm not willing to reconsider the hash function.
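
(For illustration only, not part of the PR: a self-contained sketch of the failure mode described above. WeakSet.add hashes the new object, and a key-based __hash__ fails before __init__ has run. The Tracked class is hypothetical.)

import weakref


class Tracked:
    _instances = weakref.WeakSet()

    def __new__(cls, *args, **kwargs):
        inst = object.__new__(cls)
        cls._instances.add(inst)  # hashing inst reads inst.key, which does not exist yet
        return inst

    def __init__(self, key):
        self.key = key

    def __hash__(self):
        return hash(self.key)


try:
    Tracked("x")
except AttributeError as e:
    print(e)  # 'Tracked' object has no attribute 'key'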


    def __repr__(self) -> str:
        return f"<TaskState {self.key!r} {self.state}>"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TaskState) or other.key != self.key:
            return False
-        # When a task transitions to forgotten and exits Worker.tasks, it should be
-        # immediately dereferenced. If the same task is recreated later on on the
-        # worker, we should not have to deal with its previous incarnation lingering.
-        assert other is self
-        return True
+        # A task may be forgotten and a new TaskState object with the same key may be created in
+        # its place later on. In the Worker state, you should never have multiple TaskState objects with
+        # the same key. We can't assert it here however, as this comparison is also used in WeakSets
+        # for instance tracking purposes.
+        return other is self

    def __hash__(self) -> int:
        return hash(self.key)
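
(For illustration only, not part of the PR: the __weakref__ field added in the TaskState diff above is what makes the WeakSet tracking possible at all, since a class with __slots__ is not weak-referenceable unless "__weakref__" is one of its slots. A minimal standalone demonstration with hypothetical class names:)

import weakref


class NoWeakref:
    __slots__ = ("key",)


class WithWeakref:
    __slots__ = ("key", "__weakref__")


try:
    weakref.ref(NoWeakref())
except TypeError as e:
    print(e)  # cannot create weak reference to 'NoWeakref' object

weakref.ref(WithWeakref())  # works: such instances can be placed in a WeakSet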