Use Task class instead of tuple (#8797)
fjetter authored Oct 24, 2024
1 parent af01543 commit 928d770
Showing 22 changed files with 482 additions and 635 deletions.
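In short: this commit moves the client off the classic tuple task spec and onto the object-based spec in dask._task_spec. Below is a minimal sketch of the two representations; it assumes the Task/DataNode/TaskRef API as it appears in the hunks that follow, and the keys and functions are illustrative only.

```python
# Old vs. new task spec, as adopted by this commit (illustrative keys/funcs).
from dask._task_spec import DataNode, Task, TaskRef

def add(x, y):
    return x + y

# Old style: a task is a bare tuple of (callable, *args); dependencies are
# implicit in the raw string keys embedded in the tuple.
dsk_old = {"x": 1, "y": 2, "z": (add, "x", "y")}

# New style: every node is an object that knows its own key, and
# dependencies are explicit TaskRef arguments.
dsk_new = {
    "x": DataNode("x", 1),
    "y": DataNode("y", 2),
    "z": Task("z", add, TaskRef("x"), TaskRef("y")),
}
assert dsk_new["z"].dependencies == {"x", "y"}
```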
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -161,7 +161,7 @@ jobs:
# Increase this value to reset cache if
# continuous_integration/environment-${{ matrix.environment }}.yaml has not
# changed. See also same variable in .pre-commit-config.yaml
-CACHE_NUMBER: 2
+CACHE_NUMBER: 0
id: cache

- name: Update environment
5 changes: 3 additions & 2 deletions distributed/actor.py
@@ -10,16 +10,17 @@

from tornado.ioloop import IOLoop

+from dask._task_spec import TaskRef
+
from distributed.client import Future
from distributed.protocol import to_serialize
from distributed.utils import LateLoopEvent, iscoroutinefunction, sync, thread_state
-from distributed.utils_comm import WrappedKey
from distributed.worker import get_client, get_worker

_T = TypeVar("_T")


-class Actor(WrappedKey):
+class Actor(TaskRef):
"""Controls an object on a remote worker
An actor allows remote control of a stateful object living on a remote
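Actor now derives from dask._task_spec.TaskRef rather than distributed's own WrappedKey. TaskRef is the task spec's notion of "a reference to another task by key", so any subclass can be dropped straight into a new-style graph as a dependency. A sketch of that idea, with a hypothetical subclass and key:

```python
# Sketch: a TaskRef subclass (as Actor and Future now are) works as a graph
# dependency out of the box; "remote-data" is a hypothetical key.
from dask._task_spec import Task, TaskRef

class MyRef(TaskRef):
    """Stand-in for Actor/Future: a key plus client-side behavior."""

t = Task("out", str.upper, MyRef("remote-data"))
assert t.dependencies == {"remote-data"}
```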
50 changes: 30 additions & 20 deletions distributed/client.py
@@ -37,9 +37,6 @@
cast,
)

-if TYPE_CHECKING:
-    from typing_extensions import TypeAlias
-
from packaging.version import parse as parse_version
from tlz import first, groupby, merge, partition_all, valmap

@@ -52,7 +49,6 @@
from dask.tokenize import tokenize
from dask.typing import Key, NoDefault, no_default
from dask.utils import (
-apply,
ensure_dict,
format_bytes,
funcname,
@@ -74,6 +70,8 @@
from tornado import gen
from tornado.ioloop import IOLoop

+from dask._task_spec import DataNode, GraphNode, Task, TaskRef
+
import distributed.utils
from distributed import cluster_dump, preloading
from distributed import versions as version_module
@@ -123,7 +121,6 @@
thread_state,
)
from distributed.utils_comm import (
-WrappedKey,
gather_from_workers,
pack_data,
retry_operation,
@@ -132,6 +129,9 @@
)
from distributed.worker import get_client, get_worker, secede

+if TYPE_CHECKING:
+    from typing_extensions import TypeAlias
+
logger = logging.getLogger(__name__)

_global_clients: weakref.WeakValueDictionary[int, Client] = (
@@ -250,7 +250,7 @@ def _del_global_client(c: Client) -> None:
pass


-class Future(WrappedKey):
+class Future(TaskRef):
"""A remotely running computation
A Future is a local proxy to a result running on a remote worker. A user
@@ -598,6 +598,9 @@ def __del__(self):
except RuntimeError: # closed event loop
pass

+def __str__(self):
+    return repr(self)
+
def __repr__(self):
if self.type:
return (
@@ -616,6 +619,9 @@ def _repr_html_(self):
def __await__(self):
return self.result().__await__()

+def __hash__(self):
+    return hash(self._id)
+

class FutureState:
"""A Future's internal state.
@@ -813,7 +819,7 @@ class VersionsDict(TypedDict):
client: dict[str, dict[str, Any]]


-_T_LowLevelGraph: TypeAlias = dict[Key, tuple]
+_T_LowLevelGraph: TypeAlias = dict[Key, GraphNode]


def _is_nested(iterable):
@@ -905,7 +911,7 @@ def get_ordered_keys(self):
def is_materialized(self) -> bool:
return hasattr(self, "_cached_dict")

-def __getitem__(self, key: Key) -> tuple:
+def __getitem__(self, key: Key) -> GraphNode:
return self._dict[key]

def __iter__(self) -> Iterator[Key]:
@@ -919,7 +925,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:

if not self.kwargs:
dsk = {
-key: (self.func,) + args
+key: Task(key, self.func, *args)
for key, args in zip(self._keys, zip(*self.iterables))
}

@@ -928,15 +934,15 @@ def _construct_graph(self) -> _T_LowLevelGraph:
dsk = {}
for k, v in self.kwargs.items():
if sizeof(v) > 1e5:
-vv = dask.delayed(v)
-kwargs2[k] = vv._key
-dsk.update(vv.dask)
+vv = DataNode(k, v)
+kwargs2[k] = vv.ref()
+dsk[vv.key] = vv
else:
kwargs2[k] = v

dsk.update(
{
-key: (apply, self.func, (tuple, list(args)), kwargs2)
+key: Task(key, self.func, *args, **kwargs2)
for key, args in zip(self._keys, zip(*self.iterables))
}
)
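Note the new handling of large keyword arguments above: instead of routing them through dask.delayed, each one becomes a DataNode stored once in the graph and referenced from every generated task. A rough sketch of the graph this branch now builds, using hypothetical keys and a hypothetical process function (the size threshold, per the code above, is sizeof(v) > 1e5):

```python
# Rough sketch of the DataNode path in _construct_graph (hypothetical names).
from dask._task_spec import DataNode, Task

def process(x, lookup=None):
    return lookup[x]

lookup = {i: i * i for i in range(50_000)}  # large kwarg -> its own node
node = DataNode("lookup", lookup)           # embedded once in the graph
dsk = {
    node.key: node,
    "process-0": Task("process-0", process, 0, lookup=node.ref()),
    "process-1": Task("process-1", process, 1, lookup=node.ref()),
}
```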
@@ -2158,10 +2164,14 @@ def submit(
if isinstance(workers, (str, Number)):
workers = [workers]

-if kwargs:
-    dsk = {key: (apply, func, list(args), kwargs)}
-else:
-    dsk = {key: (func,) + tuple(args)}
+dsk = {
+    key: Task(
+        key,
+        func,
+        *args,
+        **kwargs,
+    )
+}
futures = self._graph_to_futures(
dsk,
[key],
@@ -3374,7 +3384,7 @@ def _graph_to_futures(
"op": "update-graph",
"graph_header": header,
"graph_frames": frames,
"keys": list(keys),
"keys": set(keys),
"internal_priority": internal_priority,
"submitting_task": getattr(thread_state, "key", None),
"fifo_timeout": fifo_timeout,
@@ -4460,7 +4470,7 @@ def dump_cluster_state(
self,
filename: str = "dask-cluster-dump",
write_from_scheduler: bool | None = None,
exclude: Collection[str] = ("run_spec",),
exclude: Collection[str] = (),
format: Literal["msgpack", "yaml"] = "msgpack",
**storage_options,
):
@@ -6100,7 +6110,7 @@ def futures_of(o, client=None):
stack.extend(x.values())
elif type(x) is SubgraphCallable:
stack.extend(x.dsk.values())
-elif isinstance(x, WrappedKey):
+elif isinstance(x, TaskRef):
if x not in seen:
seen.add(x)
futures.append(x)
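Two details in this file are worth calling out. Because Task carries keyword arguments natively, submit no longer needs the dask.utils.apply indirection the tuple spec required (hence the deleted apply import). And with Future subclassing TaskRef, traversals like futures_of above reduce to a single isinstance check. A condensed, simplified sketch of such a traversal:

```python
# Simplified sketch of a futures_of-style walk: one base class now covers
# Future, Actor, and any other remote reference.
from dask._task_spec import TaskRef

def collect_refs(obj):
    """Gather every TaskRef nested in lists, tuples, sets, and dicts."""
    refs, stack = set(), [obj]
    while stack:
        x = stack.pop()
        if isinstance(x, (list, tuple, set)):
            stack.extend(x)
        elif isinstance(x, dict):
            stack.extend(x.values())
        elif isinstance(x, TaskRef):
            refs.add(x)
    return refs
```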
4 changes: 3 additions & 1 deletion distributed/deploy/tests/test_cluster.py
@@ -36,7 +36,9 @@ async def test_repr():

@gen_test()
async def test_cluster_wait_for_worker():
-async with LocalCluster(n_workers=2, asynchronous=True) as cluster:
+async with LocalCluster(
+    n_workers=2, asynchronous=True, dashboard_address=":0"
+) as cluster:
assert len(cluster.scheduler.workers) == 2
cluster.scale(4)
await cluster.wait_for_workers(4)
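(Here and in the next file, dashboard_address=":0" makes the scheduler bind its dashboard to a random free port, presumably so concurrently running test clusters cannot collide on the default port 8787.)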
24 changes: 19 additions & 5 deletions distributed/deploy/tests/test_local.py
@@ -1066,7 +1066,11 @@ async def test_threads_per_worker_set_to_0():
Warning, match="Setting `threads_per_worker` to 0 has been deprecated."
):
async with LocalCluster(
-n_workers=2, processes=False, threads_per_worker=0, asynchronous=True
+n_workers=2,
+processes=False,
+threads_per_worker=0,
+asynchronous=True,
+dashboard_address=":0",
) as cluster:
assert len(cluster.workers) == 2
assert all(w.state.nthreads < CPU_COUNT for w in cluster.workers.values())
@@ -1170,7 +1174,10 @@ async def test_local_cluster_redundant_kwarg(nanny):
@gen_test()
async def test_cluster_info_sync():
async with LocalCluster(
-processes=False, asynchronous=True, scheduler_sync_interval="1ms"
+processes=False,
+asynchronous=True,
+scheduler_sync_interval="1ms",
+dashboard_address=":0",
) as cluster:
assert cluster._cluster_info["name"] == cluster.name

Expand All @@ -1197,7 +1204,10 @@ async def test_cluster_info_sync():
@gen_test()
async def test_cluster_info_sync_is_robust_to_network_blips(monkeypatch):
async with LocalCluster(
-processes=False, asynchronous=True, scheduler_sync_interval="1ms"
+processes=False,
+asynchronous=True,
+scheduler_sync_interval="1ms",
+dashboard_address=":0",
) as cluster:
assert cluster._cluster_info["name"] == cluster.name

@@ -1235,7 +1245,9 @@ async def error(*args, **kwargs):
@gen_test()
async def test_cluster_host_used_throughout_cluster(host, use_nanny):
"""Ensure that the `host` kwarg is propagated through scheduler, nanny, and workers"""
-async with LocalCluster(host=host, asynchronous=True) as cluster:
+async with LocalCluster(
+    host=host, asynchronous=True, dashboard_address=":0"
+) as cluster:
url = urlparse(cluster.scheduler_address)
assert url.hostname == "127.0.0.1"
for worker in cluster.workers.values():
Expand All @@ -1249,7 +1261,9 @@ async def test_cluster_host_used_throughout_cluster(host, use_nanny):

@gen_test()
async def test_connect_to_closed_cluster():
-async with LocalCluster(processes=False, asynchronous=True) as cluster:
+async with LocalCluster(
+    processes=False, asynchronous=True, dashboard_address=":0"
+) as cluster:
async with Client(cluster, asynchronous=True) as c1:
assert await c1.submit(inc, 1) == 2

37 changes: 14 additions & 23 deletions distributed/recreate_tasks.py
@@ -7,7 +7,6 @@
from distributed.client import Future, futures_of, wait
from distributed.protocol.serialize import ToPickle
from distributed.utils import sync
-from distributed.utils_comm import pack_data

logger = logging.getLogger(__name__)

@@ -40,11 +39,8 @@ def get_error_cause(self, *args, keys=(), **kwargs):

def get_runspec(self, *args, key=None, **kwargs):
key = self._process_key(key)
-ts = self.scheduler.tasks.get(key)
-return {
-    "task": ToPickle(ts.run_spec),
-    "deps": [dts.key for dts in ts.dependencies],
-}
+ts = self.scheduler.tasks[key]
+return ToPickle(ts.run_spec)


class ReplayTaskClient:
@@ -61,10 +57,8 @@ def __init__(self, client):
self.client = client
self.client.extensions["replay-tasks"] = self
# monkey patch
-self.client._get_raw_components_from_future = (
-    self._get_raw_components_from_future
-)
-self.client._prepare_raw_components = self._prepare_raw_components
+self.client._get_raw_components_from_future = self._get_task_runspec
+self.client._prepare_raw_components = self._get_dependencies
self.client._get_components_from_future = self._get_components_from_future
self.client._get_errored_future = self._get_errored_future
self.client.recreate_task_locally = self.recreate_task_locally
@@ -74,7 +68,7 @@ def __init__(self, client):
def scheduler(self):
return self.client.scheduler

-async def _get_raw_components_from_future(self, future):
+async def _get_task_runspec(self, future):
"""
For a given future return the func, args and kwargs and future
deps that would be executed remotely.
@@ -85,28 +79,25 @@ async def _get_raw_components_from_future(self, future):
else:
validate_key(future)
key = future
-spec = await self.scheduler.get_runspec(key=key)
-return (*spec["task"], spec["deps"])
+run_spec = await self.scheduler.get_runspec(key=key)
+return run_spec

-async def _prepare_raw_components(self, raw_components):
+async def _get_dependencies(self, dependencies):
"""
Take raw components and resolve future dependencies.
"""
-function, args, kwargs, deps = raw_components
-futures = self.client._graph_to_futures({}, deps, span_metadata={})
+futures = self.client._graph_to_futures({}, dependencies, span_metadata={})
data = await self.client._gather(futures)
-args = pack_data(args, data)
-kwargs = pack_data(kwargs, data)
-return (function, args, kwargs)
+return data

async def _get_components_from_future(self, future):
"""
For a given future return the func, args and kwargs that would be
executed remotely. Any args/kwargs that are themselves futures will
be resolved to the return value of those futures.
"""
-raw_components = await self._get_raw_components_from_future(future)
-return await self._prepare_raw_components(raw_components)
+runspec = await self._get_task_runspec(future)
+return runspec, await self._get_dependencies(runspec.dependencies)

def recreate_task_locally(self, future):
"""
@@ -137,10 +128,10 @@ def recreate_task_locally(self):
-------
Any; will return the result of the task future.
"""
-func, args, kwargs = sync(
+runspec, dependencies = sync(
self.client.loop, self._get_components_from_future, future
)
-return func(*args, **kwargs)
+return runspec(dependencies)

async def _get_errored_future(self, future):
"""
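The replay path gets simpler because a Task run spec is directly callable: the scheduler ships the pickled Task, the client gathers the values of its dependencies, and calling the task with that mapping executes it locally. A minimal sketch with hypothetical keys, assuming Task.__call__ accepts a {key: value} mapping as the runspec(dependencies) call above implies:

```python
# Minimal sketch of recreating a failed task locally (hypothetical keys).
from dask._task_spec import Task, TaskRef

def div(x, y):
    return x / y

run_spec = Task("z", div, TaskRef("num"), TaskRef("den"))  # shipped by scheduler
dependencies = {"num": 1, "den": 0}                        # gathered by client

try:
    run_spec(dependencies)  # executes locally with dependencies resolved
except ZeroDivisionError:
    print("reproduced the remote error locally")
```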
(diffs for the remaining 16 of the 22 changed files not shown)
