diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 2deffff9d6..262f320ec8 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -161,7 +161,7 @@ jobs: # Increase this value to reset cache if # continuous_integration/environment-${{ matrix.environment }}.yaml has not # changed. See also same variable in .pre-commit-config.yaml - CACHE_NUMBER: 2 + CACHE_NUMBER: 0 id: cache - name: Update environment diff --git a/distributed/actor.py b/distributed/actor.py index 0af83daf63..69dd5add64 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -10,16 +10,17 @@ from tornado.ioloop import IOLoop +from dask._task_spec import TaskRef + from distributed.client import Future from distributed.protocol import to_serialize from distributed.utils import LateLoopEvent, iscoroutinefunction, sync, thread_state -from distributed.utils_comm import WrappedKey from distributed.worker import get_client, get_worker _T = TypeVar("_T") -class Actor(WrappedKey): +class Actor(TaskRef): """Controls an object on a remote worker An actor allows remote control of a stateful object living on a remote diff --git a/distributed/client.py b/distributed/client.py index 58781109da..be89e2cf8d 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -37,9 +37,6 @@ cast, ) -if TYPE_CHECKING: - from typing_extensions import TypeAlias - from packaging.version import parse as parse_version from tlz import first, groupby, merge, partition_all, valmap @@ -52,7 +49,6 @@ from dask.tokenize import tokenize from dask.typing import Key, NoDefault, no_default from dask.utils import ( - apply, ensure_dict, format_bytes, funcname, @@ -74,6 +70,8 @@ from tornado import gen from tornado.ioloop import IOLoop +from dask._task_spec import DataNode, GraphNode, Task, TaskRef + import distributed.utils from distributed import cluster_dump, preloading from distributed import versions as version_module @@ -123,7 +121,6 @@ thread_state, ) from distributed.utils_comm import ( - WrappedKey, gather_from_workers, pack_data, retry_operation, @@ -132,6 +129,9 @@ ) from distributed.worker import get_client, get_worker, secede +if TYPE_CHECKING: + from typing_extensions import TypeAlias + logger = logging.getLogger(__name__) _global_clients: weakref.WeakValueDictionary[int, Client] = ( @@ -250,7 +250,7 @@ def _del_global_client(c: Client) -> None: pass -class Future(WrappedKey): +class Future(TaskRef): """A remotely running computation A Future is a local proxy to a result running on a remote worker. A user @@ -598,6 +598,9 @@ def __del__(self): except RuntimeError: # closed event loop pass + def __str__(self): + return repr(self) + def __repr__(self): if self.type: return ( @@ -616,6 +619,9 @@ def _repr_html_(self): def __await__(self): return self.result().__await__() + def __hash__(self): + return hash(self._id) + class FutureState: """A Future's internal state. 
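Note (not part of the patch): the hunks above and below migrate distributed from ad-hoc tuple task specs and WrappedKey to the dask._task_spec model (Task, TaskRef, DataNode, GraphNode). The following is a minimal orientation sketch of that model, assuming the dask._task_spec API behaves as it is used elsewhere in this diff; the keys and the add function are made up for illustration.

from dask._task_spec import DataNode, Task, TaskRef

def add(x, y):
    return x + y

# Old-style tuple graph:          {"x": 1, "y": (add, "x", 10)}
# New GraphNode-based graph:
x = DataNode("x", 1)                    # literal payload held directly in the graph
y = Task("y", add, TaskRef("x"), 10)    # dependencies are explicit TaskRefs
dsk = {x.key: x, y.key: y}

assert y.dependencies == {"x"}          # deps come from the node, not re-derived from tuples
# A worker executes a Task by handing it the resolved dependency values,
# mirroring `result = task(data)` in _run_task_simple further down:
assert y({"x": 1}) == 11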
@@ -813,7 +819,7 @@ class VersionsDict(TypedDict): client: dict[str, dict[str, Any]] -_T_LowLevelGraph: TypeAlias = dict[Key, tuple] +_T_LowLevelGraph: TypeAlias = dict[Key, GraphNode] def _is_nested(iterable): @@ -905,7 +911,7 @@ def get_ordered_keys(self): def is_materialized(self) -> bool: return hasattr(self, "_cached_dict") - def __getitem__(self, key: Key) -> tuple: + def __getitem__(self, key: Key) -> GraphNode: return self._dict[key] def __iter__(self) -> Iterator[Key]: @@ -919,7 +925,7 @@ def _construct_graph(self) -> _T_LowLevelGraph: if not self.kwargs: dsk = { - key: (self.func,) + args + key: Task(key, self.func, *args) for key, args in zip(self._keys, zip(*self.iterables)) } @@ -928,15 +934,15 @@ def _construct_graph(self) -> _T_LowLevelGraph: dsk = {} for k, v in self.kwargs.items(): if sizeof(v) > 1e5: - vv = dask.delayed(v) - kwargs2[k] = vv._key - dsk.update(vv.dask) + vv = DataNode(k, v) + kwargs2[k] = vv.ref() + dsk[vv.key] = vv else: kwargs2[k] = v dsk.update( { - key: (apply, self.func, (tuple, list(args)), kwargs2) + key: Task(key, self.func, *args, **kwargs2) for key, args in zip(self._keys, zip(*self.iterables)) } ) @@ -2158,10 +2164,14 @@ def submit( if isinstance(workers, (str, Number)): workers = [workers] - if kwargs: - dsk = {key: (apply, func, list(args), kwargs)} - else: - dsk = {key: (func,) + tuple(args)} + dsk = { + key: Task( + key, + func, + *args, + **kwargs, + ) + } futures = self._graph_to_futures( dsk, [key], @@ -3374,7 +3384,7 @@ def _graph_to_futures( "op": "update-graph", "graph_header": header, "graph_frames": frames, - "keys": list(keys), + "keys": set(keys), "internal_priority": internal_priority, "submitting_task": getattr(thread_state, "key", None), "fifo_timeout": fifo_timeout, @@ -4460,7 +4470,7 @@ def dump_cluster_state( self, filename: str = "dask-cluster-dump", write_from_scheduler: bool | None = None, - exclude: Collection[str] = ("run_spec",), + exclude: Collection[str] = (), format: Literal["msgpack", "yaml"] = "msgpack", **storage_options, ): @@ -6100,7 +6110,7 @@ def futures_of(o, client=None): stack.extend(x.values()) elif type(x) is SubgraphCallable: stack.extend(x.dsk.values()) - elif isinstance(x, WrappedKey): + elif isinstance(x, TaskRef): if x not in seen: seen.add(x) futures.append(x) diff --git a/distributed/deploy/tests/test_cluster.py b/distributed/deploy/tests/test_cluster.py index b40a15f2c4..ec3512d0b4 100644 --- a/distributed/deploy/tests/test_cluster.py +++ b/distributed/deploy/tests/test_cluster.py @@ -36,7 +36,9 @@ async def test_repr(): @gen_test() async def test_cluster_wait_for_worker(): - async with LocalCluster(n_workers=2, asynchronous=True) as cluster: + async with LocalCluster( + n_workers=2, asynchronous=True, dashboard_address=":0" + ) as cluster: assert len(cluster.scheduler.workers) == 2 cluster.scale(4) await cluster.wait_for_workers(4) diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 59abc15573..f4b422af62 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -1066,7 +1066,11 @@ async def test_threads_per_worker_set_to_0(): Warning, match="Setting `threads_per_worker` to 0 has been deprecated." 
): async with LocalCluster( - n_workers=2, processes=False, threads_per_worker=0, asynchronous=True + n_workers=2, + processes=False, + threads_per_worker=0, + asynchronous=True, + dashboard_address=":0", ) as cluster: assert len(cluster.workers) == 2 assert all(w.state.nthreads < CPU_COUNT for w in cluster.workers.values()) @@ -1170,7 +1174,10 @@ async def test_local_cluster_redundant_kwarg(nanny): @gen_test() async def test_cluster_info_sync(): async with LocalCluster( - processes=False, asynchronous=True, scheduler_sync_interval="1ms" + processes=False, + asynchronous=True, + scheduler_sync_interval="1ms", + dashboard_address=":0", ) as cluster: assert cluster._cluster_info["name"] == cluster.name @@ -1197,7 +1204,10 @@ async def test_cluster_info_sync(): @gen_test() async def test_cluster_info_sync_is_robust_to_network_blips(monkeypatch): async with LocalCluster( - processes=False, asynchronous=True, scheduler_sync_interval="1ms" + processes=False, + asynchronous=True, + scheduler_sync_interval="1ms", + dashboard_address=":0", ) as cluster: assert cluster._cluster_info["name"] == cluster.name @@ -1235,7 +1245,9 @@ async def error(*args, **kwargs): @gen_test() async def test_cluster_host_used_throughout_cluster(host, use_nanny): """Ensure that the `host` kwarg is propagated through scheduler, nanny, and workers""" - async with LocalCluster(host=host, asynchronous=True) as cluster: + async with LocalCluster( + host=host, asynchronous=True, dashboard_address=":0" + ) as cluster: url = urlparse(cluster.scheduler_address) assert url.hostname == "127.0.0.1" for worker in cluster.workers.values(): @@ -1249,7 +1261,9 @@ async def test_cluster_host_used_throughout_cluster(host, use_nanny): @gen_test() async def test_connect_to_closed_cluster(): - async with LocalCluster(processes=False, asynchronous=True) as cluster: + async with LocalCluster( + processes=False, asynchronous=True, dashboard_address=":0" + ) as cluster: async with Client(cluster, asynchronous=True) as c1: assert await c1.submit(inc, 1) == 2 diff --git a/distributed/recreate_tasks.py b/distributed/recreate_tasks.py index 78adfda09f..f4ccdd0106 100644 --- a/distributed/recreate_tasks.py +++ b/distributed/recreate_tasks.py @@ -7,7 +7,6 @@ from distributed.client import Future, futures_of, wait from distributed.protocol.serialize import ToPickle from distributed.utils import sync -from distributed.utils_comm import pack_data logger = logging.getLogger(__name__) @@ -40,11 +39,8 @@ def get_error_cause(self, *args, keys=(), **kwargs): def get_runspec(self, *args, key=None, **kwargs): key = self._process_key(key) - ts = self.scheduler.tasks.get(key) - return { - "task": ToPickle(ts.run_spec), - "deps": [dts.key for dts in ts.dependencies], - } + ts = self.scheduler.tasks[key] + return ToPickle(ts.run_spec) class ReplayTaskClient: @@ -61,10 +57,8 @@ def __init__(self, client): self.client = client self.client.extensions["replay-tasks"] = self # monkey patch - self.client._get_raw_components_from_future = ( - self._get_raw_components_from_future - ) - self.client._prepare_raw_components = self._prepare_raw_components + self.client._get_raw_components_from_future = self._get_task_runspec + self.client._prepare_raw_components = self._get_dependencies self.client._get_components_from_future = self._get_components_from_future self.client._get_errored_future = self._get_errored_future self.client.recreate_task_locally = self.recreate_task_locally @@ -74,7 +68,7 @@ def __init__(self, client): def scheduler(self): return self.client.scheduler - 
async def _get_raw_components_from_future(self, future): + async def _get_task_runspec(self, future): """ For a given future return the func, args and kwargs and future deps that would be executed remotely. @@ -85,19 +79,16 @@ async def _get_raw_components_from_future(self, future): else: validate_key(future) key = future - spec = await self.scheduler.get_runspec(key=key) - return (*spec["task"], spec["deps"]) + run_spec = await self.scheduler.get_runspec(key=key) + return run_spec - async def _prepare_raw_components(self, raw_components): + async def _get_dependencies(self, dependencies): """ Take raw components and resolve future dependencies. """ - function, args, kwargs, deps = raw_components - futures = self.client._graph_to_futures({}, deps, span_metadata={}) + futures = self.client._graph_to_futures({}, dependencies, span_metadata={}) data = await self.client._gather(futures) - args = pack_data(args, data) - kwargs = pack_data(kwargs, data) - return (function, args, kwargs) + return data async def _get_components_from_future(self, future): """ @@ -105,8 +96,8 @@ async def _get_components_from_future(self, future): executed remotely. Any args/kwargs that are themselves futures will be resolved to the return value of those futures. """ - raw_components = await self._get_raw_components_from_future(future) - return await self._prepare_raw_components(raw_components) + runspec = await self._get_task_runspec(future) + return runspec, await self._get_dependencies(runspec.dependencies) def recreate_task_locally(self, future): """ @@ -137,10 +128,10 @@ def recreate_task_locally(self, future): ------- Any; will return the result of the task future. """ - func, args, kwargs = sync( + runspec, dependencies = sync( self.client.loop, self._get_components_from_future, future ) - return func(*args, **kwargs) + return runspec(dependencies) async def _get_errored_future(self, future): """ diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 38834988ca..35f2aecb22 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -54,8 +54,14 @@ import dask import dask.utils -from dask.core import get_deps, iskey, validate_key -from dask.tokenize import TokenizationError, normalize_token, tokenize +from dask._task_spec import ( + DependenciesMapping, + GraphNode, + convert_legacy_graph, + resolve_aliases, +) +from dask.base import TokenizationError, normalize_token, tokenize +from dask.core import istask, reverse_dict, validate_key from dask.typing import Key, no_default from dask.utils import ( _deprecated, @@ -135,10 +141,8 @@ gather_from_workers, retry_operation, scatter_to_workers, - unpack_remotedata, ) from distributed.variable import VariableExtension -from distributed.worker import _normalize_task if TYPE_CHECKING: # TODO import from typing (requires Python >=3.10) @@ -169,7 +173,7 @@ # (recommendations, client messages, worker messages) RecsMsgs: TypeAlias = tuple[Recs, Msgs, Msgs] -T_runspec: TypeAlias = tuple[Callable, tuple, dict[str, Any]] +T_runspec: TypeAlias = GraphNode logger = logging.getLogger(__name__) LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") @@ -3176,13 +3180,14 @@ def get_task_duration(self, ts: TaskState) -> float: for this task, `distributed.scheduler.unknown-task-duration` is used instead. 
""" - duration: float = ts.prefix.duration_average + prefix = ts.prefix + duration: float = prefix.duration_average if duration >= 0: return duration - s = self.unknown_durations.get(ts.prefix.name) + s = self.unknown_durations.get(prefix.name) if s is None: - self.unknown_durations[ts.prefix.name] = s = set() + self.unknown_durations[prefix.name] = s = set() s.add(ts) return self.UNKNOWN_TASK_DURATION @@ -3580,7 +3585,7 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]: "duration": duration, "stimulus_id": f"compute-task-{time()}", "who_has": { - dts.key: [ws.address for ws in dts.who_has or ()] + dts.key: tuple(ws.address for ws in (dts.who_has or ())) for dts in ts.dependencies }, "nbytes": {dts.key: dts.nbytes for dts in ts.dependencies}, @@ -4607,58 +4612,33 @@ async def add_nanny(self, comm: Comm, address: str) -> None: self._starting_nannies.discard(address) self._starting_nannies_cond.notify_all() - def _match_graph_with_tasks( + def _find_lost_dependencies( self, dsk: dict[Key, T_runspec], dependencies: dict[Key, set[Key]], keys: set[Key], ) -> set[Key]: - n = -1 lost_keys = set() - while len(dsk) != n: # walk through new tasks, cancel any bad deps - n = len(dsk) - for k, deps in list(dependencies.items()): - if (k not in self.tasks and k not in dsk) or any( - dep not in self.tasks and dep not in dsk for dep in deps - ): # bad key - lost_keys.add(k) - logger.info("User asked for computation on lost data, %s", k) - dsk.pop(k, None) - del dependencies[k] - if k in keys: - keys.remove(k) - del deps - # Avoid computation that is already finished - done = set() # tasks that are already done - for k, v in dependencies.items(): - if v and k in self.tasks: - ts = self.tasks[k] - if ts.state in ("memory", "erred"): - done.add(k) - - if done: - dependents = dask.core.reverse_dict(dependencies) - stack = list(done) - while stack: # remove unnecessary dependencies - key = stack.pop() - try: - deps = dependencies[key] - except KeyError: - deps = {ts.key for ts in self.tasks[key].dependencies} - for dep in deps: - if dep in dependents: - child_deps = dependents[dep] - elif dep in self.tasks: - child_deps = {ts.key for ts in self.tasks[key].dependencies} - else: - child_deps = set() - if all(d in done for d in child_deps): - if dep in self.tasks and dep not in done: - done.add(dep) - stack.append(dep) - for anc in done: - dsk.pop(anc, None) - dependencies.pop(anc, None) + seen: set[Key] = set() + sadd = seen.add + for k in list(keys): + work = {k} + wpop = work.pop + wupdate = work.update + while work: + d = wpop() + if d in seen: + continue + sadd(d) + if d not in dsk: + if d not in self.tasks: + lost_keys.add(d) + lost_keys.add(k) + logger.info("User asked for computation on lost data, %s", k) + dependencies.pop(d, None) + keys.discard(k) + continue + wupdate(dsk[d].dependencies) return lost_keys def _create_taskstate_from_graph( @@ -4679,7 +4659,7 @@ def _create_taskstate_from_graph( actors: bool | list[Key] | None = None, fifo_timeout: float = 0.0, code: tuple[SourceCode, ...] = (), - ) -> None: + ) -> dict[str, float]: """ Take a low level graph and create the necessary scheduler state to compute it. @@ -4691,14 +4671,6 @@ def _create_taskstate_from_graph( in the same event loop tick. 
""" - lost_keys = self._match_graph_with_tasks(dsk, dependencies, keys) - - if lost_keys: - self.report({"op": "cancelled-keys", "keys": lost_keys}, client=client) - self.client_releases_keys( - keys=lost_keys, client=client, stimulus_id=stimulus_id - ) - if not self.is_idle and self.computations: # Still working on something. Assign new tasks to same computation computation = self.computations[-1] @@ -4713,7 +4685,6 @@ def _create_taskstate_from_graph( # annotations. computation.annotations.update(global_annotations) del global_annotations - ( runnable, touched_tasks, @@ -4726,15 +4697,11 @@ def _create_taskstate_from_graph( computation=computation, ) - if len(dsk) > 1 or colliding_task_count: - self.log_event( - ["all", client], - { - "action": "update_graph", - "count": len(dsk), - "key-collisions": colliding_task_count, - }, - ) + metrics = { + "tasks": len(dsk), + "new_tasks": len(new_tasks), + "key_collisions": colliding_task_count, + } keys_with_annotations = self._apply_annotations( tasks=new_tasks, @@ -4823,6 +4790,44 @@ def _create_taskstate_from_graph( if ts.state in ("memory", "erred"): self.report_on_key(ts=ts, client=client) + return metrics + + def _remove_done_tasks_from_dsk( + self, + dsk: dict[Key, T_runspec], + dependencies: dict[Key, set[Key]], + ) -> None: + # Avoid computation that is already finished + done = set() # tasks that are already done + for k, v in dependencies.items(): + if v and k in self.tasks: + ts = self.tasks[k] + if ts.state in ("memory", "erred"): + done.add(k) + if done: + dependents = dask.core.reverse_dict(dependencies) + stack = list(done) + while stack: # remove unnecessary dependencies + key = stack.pop() + try: + deps = dependencies[key] + except KeyError: + deps = {ts.key for ts in self.tasks[key].dependencies} + for dep in deps: + if dep in dependents: + child_deps = dependents[dep] + elif dep in self.tasks: + child_deps = {ts.key for ts in self.tasks[key].dependencies} + else: + child_deps = set() + if all(d in done for d in child_deps): + if dep in self.tasks and dep not in done: + done.add(dep) + stack.append(dep) + for anc in done: + dsk.pop(anc, None) + dependencies.pop(anc, None) + @log_errors async def update_graph( self, @@ -4843,6 +4848,7 @@ async def update_graph( start = time() self._active_graph_updates += 1 try: + logger.debug("Received new graph. Deserializing...") try: graph = deserialize(graph_header, graph_frames).data del graph_header, graph_frames @@ -4862,9 +4868,23 @@ async def update_graph( _materialize_graph, graph=graph, global_annotations=annotations or {}, + keys=keys, validate=self.validate, ) + + materialization_done = time() + logger.debug("Materialization done. 
Got %i tasks.", len(dsk)) del graph + + lost_keys = self._find_lost_dependencies(dsk, dependencies, keys) + + if lost_keys: + self.report({"op": "cancelled-keys", "keys": lost_keys}, client=client) + self.client_releases_keys( + keys=lost_keys, client=client, stimulus_id=stimulus_id + ) + dsk = _cull(dsk, keys) + if not internal_priority: # Removing all non-local keys before calling order() dsk_keys = set( @@ -4875,12 +4895,18 @@ async def update_graph( for k, v in dependencies.items() if k in dsk_keys } + internal_priority = await offload( dask.order.order, dsk=dsk, dependencies=stripped_deps ) - dsk = valmap(_normalize_task, dsk) + ordering_done = time() + logger.debug("Ordering done.") + + before = len(self.tasks) - self._create_taskstate_from_graph( + self._remove_done_tasks_from_dsk(dsk, dependencies) + + metrics = self._create_taskstate_from_graph( dsk=dsk, client=client, dependencies=dependencies, @@ -4899,7 +4925,32 @@ async def update_graph( start=start, stimulus_id=stimulus_id or f"update-graph-{start}", ) - except RuntimeError as e: + task_state_created = time() + metrics.update( + { + "start_timestamp_seconds": start, + "materialization_duration_seconds": materialization_done - start, + "ordering_duration_seconds": materialization_done - ordering_done, + "state_initialization_duration_seconds": ordering_done + - task_state_created, + "duration_seconds": task_state_created - start, + } + ) + evt_msg = { + "action": "update-graph", + "stimulus_id": stimulus_id, + "metrics": metrics, + "status": "OK", + } + self.log_event(["scheduler", client], evt_msg) + logger.debug("Task state created. %i new tasks", len(self.tasks) - before) + except Exception as e: + evt_msg = { + "action": "update-graph", + "stimulus_id": stimulus_id, + "status": "error", + } + self.log_event(["scheduler", client], evt_msg) logger.error(str(e)) err = error_message(e) for key in keys: @@ -4928,7 +4979,7 @@ def _generate_taskstates( computation: Computation, ) -> tuple: # Get or create task states - runnable = [] + runnable = list() new_tasks = [] stack = list(keys) touched_keys = set() @@ -5110,7 +5161,7 @@ def _set_priorities( user_priority: int | dict[Key, int], fifo_timeout: int | float | str, start: float, - tasks: list[TaskState], + tasks: set[TaskState], ) -> None: fifo_timeout = parse_timedelta(fifo_timeout) if submitting_task: # sub-tasks get better priority than parent tasks @@ -5147,7 +5198,7 @@ def _set_priorities( internal_priority[ts.key], ) - if self.validate and ts.run_spec: + if self.validate and istask(ts.run_spec): assert isinstance(ts.priority, tuple) and all( isinstance(el, (int, float)) for el in ts.priority ) @@ -5971,7 +6022,8 @@ def handle_task_finished( ) -> None: if worker not in self.workers: return - self.validate_key(key) + if self.validate: + self.validate_key(key) r: tuple = self.stimulus_task_finished( key=key, worker=worker, stimulus_id=stimulus_id, **msg @@ -9333,7 +9385,10 @@ def transition( def _materialize_graph( - graph: HighLevelGraph, global_annotations: dict[str, Any], validate: bool + graph: HighLevelGraph, + global_annotations: dict[str, Any], + validate: bool, + keys: set[Key], ) -> tuple[dict[Key, T_runspec], dict[Key, set[Key]], dict[str, dict[Key, Any]]]: dsk: dict = ensure_dict(graph) if validate: @@ -9352,33 +9407,35 @@ def _materialize_graph( annotations_by_type[annot_type].update( {k: (value(k) if callable(value) else value) for k in layer} ) - dependencies, _ = get_deps(dsk) - # Remove `Future` objects from graph and note any future dependencies + dsk2 = 
convert_legacy_graph(dsk) + dependents = reverse_dict(DependenciesMapping(dsk2)) + # This is removing weird references like "x-foo": "foo" which often make up + # a substantial part of the graph + # This also performs culling! + dsk3 = resolve_aliases(dsk2, keys, dependents) + + logger.debug( + "Removing aliases. Started with %i and got %i left", len(dsk2), len(dsk3) + ) + # FIXME: There should be no need to fully materialize and copy this but some + # sections in the scheduler are mutating it. + dependencies = {k: set(v) for k, v in DependenciesMapping(dsk3).items()} + return dsk3, dependencies, annotations_by_type + + +def _cull(dsk: dict[Key, GraphNode], keys: set[Key]) -> dict[Key, GraphNode]: + work = set(keys) + seen: set[Key] = set() dsk2 = {} - fut_deps = {} - for k, v in dsk.items(): - v, futs = unpack_remotedata(v, byte_keys=True) - if futs: - fut_deps[k] = futs - - # Remove aliases {x: x}. - # FIXME: This is an artifact generated by unpack_remotedata when using persisted - # collections. There should be a better way to achieve that tasks are not self - # referencing themselves. - if not iskey(v) or v != k: - dsk2[k] = v - - dsk = dsk2 - - # - Add in deps for any tasks that depend on futures - for k, futures in fut_deps.items(): - dependencies[k].update(f.key for f in futures) - - # Remove any self-dependencies (happens on test_publish_bag() and others) - for k, v in dependencies.items(): - deps = set(v) - deps.discard(k) - dependencies[k] = deps - - return dsk, dependencies, annotations_by_type + wpop = work.pop + wupdate = work.update + sadd = seen.add + while work: + k = wpop() + if k in seen or k not in dsk: + continue + sadd(k) + dsk2[k] = v = dsk[k] + wupdate(v.dependencies) + return dsk2 diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index f62b375c0c..20644e5acb 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -121,6 +121,7 @@ import dask import dask.config +from dask._task_spec import Task, TaskRef from dask.highlevelgraph import HighLevelGraph from dask.layers import Layer from dask.tokenize import tokenize @@ -143,7 +144,6 @@ from distributed.shuffle._shuffle import barrier_key, shuffle_barrier from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.sizeof import sizeof -from distributed.utils_comm import DoNotUnpack if TYPE_CHECKING: import numpy as np @@ -746,19 +746,20 @@ def partial_concatenate( ) if _slicing_is_necessary(old_slice, original_shape): key = (slice_group,) + ndpartial.ix + old_global_index - rec_cat_arg[old_partial_index] = key - dsk[key] = ( + dsk[key] = t = Task( + key, getitem, - (input_name,) + old_global_index, + TaskRef((input_name,) + old_global_index), old_slice, ) + rec_cat_arg[old_partial_index] = t.ref() else: - rec_cat_arg[old_partial_index] = (input_name,) + old_global_index + rec_cat_arg[old_partial_index] = TaskRef((input_name,) + old_global_index) - dsk[(rechunk_name(token),) + global_new_index] = ( - concatenate3, - rec_cat_arg.tolist(), + concat_task = Task( + (rechunk_name(token),) + global_new_index, concatenate3, rec_cat_arg.tolist() ) + dsk[concat_task.key] = concat_task return dsk @@ -806,31 +807,36 @@ def partial_rechunk( for global_index in _ndindices_of_slice(ndpartial.old): partial_index = _partial_index(global_index, old_partial_offset) - input_key = (input_name,) + global_index + input_key = TaskRef((input_name,) + global_index) key = (transfer_group,) + ndpartial.ix + global_index - transfer_keys.append(key) - dsk[key] = ( + 
dsk[key] = t = Task( + key, rechunk_transfer, input_key, partial_token, - DoNotUnpack(partial_index), - DoNotUnpack(partial_new), - DoNotUnpack(partial_old), + partial_index, + partial_new, + partial_old, disk, ) + transfer_keys.append(t.ref()) - dsk[_barrier_key] = (shuffle_barrier, partial_token, transfer_keys) + dsk[_barrier_key] = barrier = Task( + _barrier_key, shuffle_barrier, partial_token, transfer_keys + ) new_partial_offset = tuple(axis.start for axis in ndpartial.new) for global_index in _ndindices_of_slice(ndpartial.new): partial_index = _partial_index(global_index, new_partial_offset) if keepmap[global_index]: - dsk[(unpack_group,) + global_index] = ( + k = (unpack_group,) + global_index + dsk[k] = Task( + k, rechunk_unpack, partial_token, partial_index, - _barrier_key, + barrier.ref(), ) return dsk diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 57e5493d25..ec942dfe60 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -21,10 +21,12 @@ from tornado.ioloop import IOLoop import dask +from dask._task_spec import GraphNode, Task, TaskRef from dask.highlevelgraph import HighLevelGraph from dask.layers import Layer from dask.tokenize import tokenize from dask.typing import Key +from dask.utils import is_dataframe_like from distributed.core import PooledRPCCall from distributed.exceptions import Reschedule @@ -161,7 +163,7 @@ def rearrange_by_column_p2p( ) -_T_LowLevelGraph: TypeAlias = dict[Key, tuple] +_T_LowLevelGraph: TypeAlias = dict[Key, GraphNode] class P2PShuffleLayer(Layer): @@ -226,7 +228,7 @@ def _dict(self) -> _T_LowLevelGraph: self._cached_dict = dsk return self._cached_dict - def __getitem__(self, key: Key) -> tuple: + def __getitem__(self, key: Key) -> GraphNode: return self._dict[key] def __iter__(self) -> Iterator[Key]: @@ -288,10 +290,10 @@ def _construct_graph(self) -> _T_LowLevelGraph: name = "shuffle-transfer-" + token transfer_keys = list() for i in range(self.npartitions_input): - transfer_keys.append((name, i)) - dsk[(name, i)] = ( + t = Task( + (name, i), shuffle_transfer, - (self.name_input, i), + TaskRef((self.name_input, i)), token, i, self.npartitions, @@ -301,17 +303,22 @@ def _construct_graph(self) -> _T_LowLevelGraph: self.disk, self.drop_column, ) + dsk[t.key] = t + transfer_keys.append(t.ref()) - dsk[_barrier_key] = (shuffle_barrier, token, transfer_keys) + barrier = Task(_barrier_key, shuffle_barrier, token, transfer_keys) + dsk[barrier.key] = barrier name = self.name for part_out in self.parts_out: - dsk[(name, part_out)] = ( + t = Task( + (name, part_out), shuffle_unpack, token, part_out, - _barrier_key, + barrier.ref(), ) + dsk[t.key] = t return dsk @@ -580,6 +587,8 @@ def pick_worker(self, partition: int, workers: Sequence[str]) -> str: return _get_worker_for_range_sharding(self.npartitions, partition, workers) def validate_data(self, data: pd.DataFrame) -> None: + if not is_dataframe_like(data): + raise TypeError(f"Expected {data=} to be a DataFrame, got {type(data)}.") if set(data.columns) != set(self.meta.columns): raise ValueError(f"Expected {self.meta.columns=} to match {data.columns=}.") diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 615f83b6b0..6451f82079 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -622,6 +622,7 @@ async def test_restarting_does_not_deadlock(c, s): assert dd.assert_eq(result, expected) +@pytest.mark.slow @gen_cluster(client=True, 
nthreads=[("", 1)] * 2) async def test_closed_input_only_worker_during_transfer(c, s, a, b): def mock_get_worker_for_range_sharding( diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py index cb573c3672..4f55158622 100644 --- a/distributed/tests/test_active_memory_manager.py +++ b/distributed/tests/test_active_memory_manager.py @@ -1266,7 +1266,7 @@ def run(self): # Instead of yielding ("drop", ts, None) for each worker, which would result # in semi-predictable output about which replica survives, randomly choose a # different survivor at each AMM run. - candidates = list(ts.who_has) + candidates = list(ts.who_has or ()) random.shuffle(candidates) for ws in candidates: yield "drop", ts, {ws} @@ -1288,13 +1288,19 @@ async def tensordot_stress(c, s): warnings.simplefilter("ignore") b = (a @ a.T).sum().round(3) assert await c.compute(b) == 245.394 - + expected_tasks = -1 + for _, msg in await c.get_events("scheduler"): + if msg["action"] == "update-graph": + assert msg["status"] == "OK", msg + expected_tasks = msg["metrics"]["tasks"] + break + else: + raise RuntimeError("Expected 'update_graph' event not found") # Test that we didn't recompute any tasks during the stress test await async_poll_for(lambda: not s.tasks, timeout=5) - assert sum(t.start == "memory" for t in s.transition_log) == 1639 + assert sum(t.start == "memory" for t in s.transition_log) == expected_tasks -@pytest.mark.slow @gen_cluster( client=True, nthreads=[("", 1)] * 4, @@ -1350,7 +1356,6 @@ async def test_ReduceReplicas_stress(c, s, *workers): await tensordot_stress(c, s) -@pytest.mark.slow @pytest.mark.parametrize("use_ReduceReplicas", [False, True]) @gen_cluster( client=True, diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 0016294e70..f80130224e 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -42,6 +42,7 @@ import dask import dask.bag as db from dask import delayed +from dask._task_spec import no_function_cache from dask.optimization import SubgraphCallable from dask.tokenize import tokenize from dask.utils import get_default_shuffle_method, parse_timedelta, tmpfile @@ -683,7 +684,9 @@ async def test_get(c, s, a, b): assert result == [] result = await c.get( - {("x", 1): (inc, 1), ("x", 2): (inc, ("x", 1))}, ("x", 2), sync=False + {("x", 1): (inc, 1), ("x", 2): (inc, ("x", 1))}, + ("x", 2), + sync=False, ) assert result == 3 @@ -2883,8 +2886,9 @@ def test_persist_get_sync(c): assert xxyy3.compute() == ((1 + 1) + (2 + 2)) + 10 +@pytest.mark.parametrize("do_wait", [True, False]) @gen_cluster(client=True) -async def test_persist_get(c, s, a, b): +async def test_persist_get(c, s, a, b, do_wait): x, y = delayed(1), delayed(2) xx = delayed(add)(x, x) yy = delayed(add)(y, y) @@ -2892,8 +2896,8 @@ async def test_persist_get(c, s, a, b): xxyy2 = c.persist(xxyy) xxyy3 = delayed(add)(xxyy2, 10) - - await asyncio.sleep(0.5) + if do_wait: + await wait(xxyy2) result = await c.gather(c.get(xxyy3.dask, xxyy3.__dask_keys__(), sync=False)) assert result[0] == ((1 + 1) + (2 + 2)) + 10 @@ -4712,8 +4716,8 @@ def f(L=None): assert result == 1 + 2 + 3 -@gen_cluster(client=True) -async def test_map_list_kwargs(c, s, a, b): +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_map_list_kwargs(c, s, a): futures = await c.scatter([1, 2, 3]) def f(i, L=None): @@ -4724,97 +4728,6 @@ def f(i, L=None): assert results == [i + 6 for i in range(10)] -@gen_cluster(client=True) -async def 
test_recreate_error_delayed(c, s, a, b): - x0 = delayed(dec)(2) - y0 = delayed(dec)(1) - x = delayed(div)(1, x0) - y = delayed(div)(1, y0) - tot = delayed(sum)(x, y) - - f = c.compute(tot) - - assert f.status == "pending" - - error_f = await c._get_errored_future(f) - function, args, kwargs = await c._get_components_from_future(error_f) - assert f.status == "error" - assert function.__name__ == "div" - assert args == (1, 0) - with pytest.raises(ZeroDivisionError): - function(*args, **kwargs) - - -@gen_cluster(client=True) -async def test_recreate_error_futures(c, s, a, b): - x0 = c.submit(dec, 2) - y0 = c.submit(dec, 1) - x = c.submit(div, 1, x0) - y = c.submit(div, 1, y0) - tot = c.submit(sum, x, y) - f = c.compute(tot) - - assert f.status == "pending" - - error_f = await c._get_errored_future(f) - function, args, kwargs = await c._get_components_from_future(error_f) - assert f.status == "error" - assert function.__name__ == "div" - assert args == (1, 0) - with pytest.raises(ZeroDivisionError): - function(*args, **kwargs) - - -@gen_cluster(client=True) -async def test_recreate_error_collection(c, s, a, b): - pd = pytest.importorskip("pandas") - dd = pytest.importorskip("dask.dataframe") - - b = db.range(10, npartitions=4) - b = b.map(lambda x: 1 / x) - b = b.persist() - f = c.compute(b) - - error_f = await c._get_errored_future(f) - function, args, kwargs = await c._get_components_from_future(error_f) - with pytest.raises(ZeroDivisionError): - function(*args, **kwargs) - - df = dd.from_pandas(pd.DataFrame({"a": [0, 1, 2, 3, 4]}), chunksize=2) - - def make_err(x): - # because pandas would happily work with NaN - if x == 0: - raise ValueError - return x - - df2 = df.a.map(make_err, meta=df.a) - f = c.compute(df2) - error_f = await c._get_errored_future(f) - function, args, kwargs = await c._get_components_from_future(error_f) - with pytest.raises(ValueError): - function(*args, **kwargs) - - # with persist - df3 = c.persist(df2) - error_f = await c._get_errored_future(df3) - function, args, kwargs = await c._get_components_from_future(error_f) - with pytest.raises(ValueError): - function(*args, **kwargs) - - -@gen_cluster(client=True) -async def test_recreate_error_array(c, s, a, b): - pytest.importorskip("numpy") - da = pytest.importorskip("dask.array") - pytest.importorskip("scipy") - z = (da.linalg.inv(da.zeros((10, 10), chunks=10)) + 1).sum() - zz = z.persist() - error_f = await c._get_errored_future(zz) - function, args, kwargs = await c._get_components_from_future(error_f) - assert "0.,0.,0." 
in str(args).replace(" ", "") # args contain actual arrays - - def test_recreate_error_sync(c): x0 = c.submit(dec, 2) y0 = c.submit(dec, 1) @@ -4834,97 +4747,6 @@ def test_recreate_error_not_error(c): c.recreate_error_locally(f) -@gen_cluster(client=True) -async def test_recreate_task_delayed(c, s, a, b): - x0 = delayed(dec)(2) - y0 = delayed(dec)(2) - x = delayed(div)(1, x0) - y = delayed(div)(1, y0) - tot = delayed(sum)([x, y]) - - f = c.compute(tot) - - assert f.status == "pending" - - function, args, kwargs = await c._get_components_from_future(f) - assert f.status == "finished" - assert function.__name__ == "sum" - assert args == ([1, 1],) - assert function(*args, **kwargs) == 2 - - -@gen_cluster(client=True) -async def test_recreate_task_futures(c, s, a, b): - x0 = c.submit(dec, 2) - y0 = c.submit(dec, 2) - x = c.submit(div, 1, x0) - y = c.submit(div, 1, y0) - tot = c.submit(sum, [x, y]) - f = c.compute(tot) - - assert f.status == "pending" - - function, args, kwargs = await c._get_components_from_future(f) - assert f.status == "finished" - assert function.__name__ == "sum" - assert args == ([1, 1],) - assert function(*args, **kwargs) == 2 - - -@gen_cluster(client=True) -async def test_recreate_task_collection(c, s, a, b): - pd = pytest.importorskip("pandas") - dd = pytest.importorskip("dask.dataframe") - - b = db.range(10, npartitions=4) - b = b.map(lambda x: int(3628800 / (x + 1))) - b = b.persist() - f = c.compute(b) - - function, args, kwargs = await c._get_components_from_future(f) - assert function(*args, **kwargs) == [ - 3628800, - 1814400, - 1209600, - 907200, - 725760, - 604800, - 518400, - 453600, - 403200, - 362880, - ] - - df = dd.from_pandas(pd.DataFrame({"a": [0, 1, 2, 3, 4]}), chunksize=2) - - df2 = df.a.map(inc, meta=df.a) - f = c.compute(df2) - - function, args, kwargs = await c._get_components_from_future(f) - expected = pd.DataFrame({"a": [1, 2, 3, 4, 5]})["a"] - assert function(*args, **kwargs).equals(expected) - - # with persist - df3 = c.persist(df2) - # recreate_task_locally only works with futures - with pytest.raises(TypeError, match="key"): - function, args, kwargs = await c._get_components_from_future(df3) - - f = c.compute(df3) - function, args, kwargs = await c._get_components_from_future(f) - assert function(*args, **kwargs).equals(expected) - - -@gen_cluster(client=True) -async def test_recreate_task_array(c, s, a, b): - pytest.importorskip("numpy") - da = pytest.importorskip("dask.array") - z = (da.zeros((10, 10), chunks=10) + 1).sum() - f = c.compute(z) - function, args, kwargs = await c._get_components_from_future(f) - assert function(*args, **kwargs) == 100 - - def test_recreate_task_sync(c): x0 = c.submit(dec, 2) y0 = c.submit(dec, 2) @@ -5111,28 +4933,30 @@ def __setstate__(self, state): @gen_cluster(client=True) -async def test_robust_undeserializable_function(c, s, a, b): - class Foo: - def __getstate__(self): - return 1 +async def test_robust_undeserializable_function(c, s, a, b, monkeypatch): + with no_function_cache(): - def __setstate__(self, state): - raise MyException("hello") + class Foo: + def __getstate__(self): + return 1 - def __call__(self, *args): - return 1 + def __setstate__(self, state): + raise MyException("hello") - future = c.submit(Foo(), 1) - await wait(future) - assert future.status == "error" - with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"): - await future + def __call__(self, *args): + return 1 - futures = c.map(inc, range(10)) - results = await c.gather(futures) + future = c.submit(Foo(), 
1) + await wait(future) + assert future.status == "error" + with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"): + await future - assert results == list(map(inc, range(10))) - assert a.data and b.data + futures = c.map(inc, range(10)) + results = await c.gather(futures) + + assert results == list(map(inc, range(10))) + assert a.data and b.data @gen_cluster(client=True) @@ -6221,7 +6045,7 @@ async def test_mixing_clients_same_scheduler(s, a, b): @gen_cluster() async def test_mixing_clients_different_scheduler(s, a, b): async with ( - Scheduler(port=open_port()) as s2, + Scheduler(port=open_port(), dashboard_address=":0") as s2, Worker(s2.address) as w1, Client(s.address, asynchronous=True) as c1, Client(s2.address, asynchronous=True) as c2, @@ -6265,7 +6089,7 @@ async def test_map_large_kwargs_in_graph(c, s, a, b): await asyncio.sleep(0.01) assert len(s.tasks) == 101 - assert any(k.startswith("ndarray") for k in s.tasks) + assert sum(sizeof(ts.run_spec) > sizeof(x) for ts in s.tasks.values()) == 1 @gen_cluster(client=True) @@ -6341,14 +6165,21 @@ async def test_profile_bokeh(c, s, a, b): assert os.path.exists(fn) -@gen_cluster(client=True) -async def test_get_mix_futures_and_SubgraphCallable(c, s, a, b): +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_get_mix_futures_and_SubgraphCallable(c, s, a): future = c.submit(add, 1, 2) subgraph = SubgraphCallable( - {"_2": (add, "_0", "_1"), "_3": (add, future, "_2")}, "_3", ("_0", "_1") + {"_2": (add, "_0", "_1"), "_3": (add, future, "_2")}, + "_3", + ("_0", "_1"), ) - dsk = {"a": 1, "b": 2, "c": (subgraph, "a", "b"), "d": (subgraph, "c", "b")} + dsk = { + "a": 1, + "b": 2, + "c": (subgraph, "a", "b"), + "d": (subgraph, "c", "b"), + } future2 = c.get(dsk, "d", sync=False) result = await future2 @@ -7993,9 +7824,9 @@ async def test_dump_cluster_state_exclude_default(c, s, a, b, tmp_path): futs = c.map(inc, range(10)) while len(s.tasks) != len(futs): await asyncio.sleep(0.01) - excluded_by_default = [ - "run_spec", - ] + # GraphNode / callables are never included + always_excluded = ["run_spec"] + exclude = ["state"] filename = tmp_path / "foo" await c.dump_cluster_state( @@ -8010,20 +7841,22 @@ async def test_dump_cluster_state_exclude_default(c, s, a, b, tmp_path): assert len(state["workers"]) == len(s.workers) for worker_dump in state["workers"].values(): for k, task_dump in worker_dump["tasks"].items(): - assert not any(blocked in task_dump for blocked in excluded_by_default) + assert not any(blocked in task_dump for blocked in always_excluded) + assert any(blocked in task_dump for blocked in exclude) assert k in s.tasks assert "scheduler" in state assert "tasks" in state["scheduler"] tasks = state["scheduler"]["tasks"] assert len(tasks) == len(futs) for k, task_dump in tasks.items(): - assert not any(blocked in task_dump for blocked in excluded_by_default) + assert not any(blocked in task_dump for blocked in always_excluded) + assert any(blocked in task_dump for blocked in exclude) assert k in s.tasks await c.dump_cluster_state( filename=filename, format="yaml", - exclude=(), + exclude=exclude, ) with open(f"{filename}.yaml") as fd: @@ -8033,14 +7866,16 @@ async def test_dump_cluster_state_exclude_default(c, s, a, b, tmp_path): assert len(state["workers"]) == len(s.workers) for worker_dump in state["workers"].values(): for k, task_dump in worker_dump["tasks"].items(): - assert all(blocked in task_dump for blocked in excluded_by_default) + assert not any(blocked in task_dump for blocked in always_excluded) 
+ assert not any(blocked in task_dump for blocked in exclude) assert k in s.tasks assert "scheduler" in state assert "tasks" in state["scheduler"] tasks = state["scheduler"]["tasks"] assert len(tasks) == len(futs) for k, task_dump in tasks.items(): - assert all(blocked in task_dump for blocked in excluded_by_default) + assert not any(blocked in task_dump for blocked in always_excluded) + assert not any(blocked in task_dump for blocked in exclude) assert k in s.tasks diff --git a/distributed/tests/test_jupyter.py b/distributed/tests/test_jupyter.py index 3f45678f81..1d195494c6 100644 --- a/distributed/tests/test_jupyter.py +++ b/distributed/tests/test_jupyter.py @@ -36,7 +36,7 @@ @gen_test() async def test_jupyter_server(): - async with Scheduler(jupyter=True) as s: + async with Scheduler(jupyter=True, dashboard_address=":0") as s: http_client = AsyncHTTPClient() response = await http_client.fetch( f"http://localhost:{s.http_server.port}/jupyter/api/status" @@ -66,7 +66,7 @@ def test_jupyter_cli(loop): @gen_test() async def test_jupyter_idle_timeout(): "An active Jupyter session should prevent idle timeout" - async with Scheduler(jupyter=True, idle_timeout=0.2) as s: + async with Scheduler(jupyter=True, idle_timeout=0.2, dashboard_address=":0") as s: web_app = s._jupyter_server_application.web_app # Jupyter offers a place for extensions to provide updates on their last-active @@ -93,7 +93,7 @@ async def test_jupyter_idle_timeout(): @gen_test() async def test_jupyter_idle_timeout_returned(): "`check_idle` should return the last Jupyter idle time. Used in dask-kubernetes." - async with Scheduler(jupyter=True) as s: + async with Scheduler(jupyter=True, dashboard_address=":0") as s: web_app = s._jupyter_server_application.web_app extension_last_activty = web_app.settings["last_activity_times"] diff --git a/distributed/tests/test_preload.py b/distributed/tests/test_preload.py index 5a003f9fa8..1d9279b524 100644 --- a/distributed/tests/test_preload.py +++ b/distributed/tests/test_preload.py @@ -186,7 +186,9 @@ async def test_web_preload(): captured_logger("distributed.preloading") as log, ): async with Scheduler( - host="localhost", preload=["http://example.com/preload"] + host="localhost", + preload=["http://example.com/preload"], + dashboard_address=":0", ) as s: assert s.foo == 1 assert ( diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 0369f6c53b..c07bc2a0a6 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -29,7 +29,6 @@ from dask.utils import parse_timedelta, tmpfile, typename from distributed import ( - CancelledError, Client, Event, Lock, @@ -4177,16 +4176,15 @@ async def test_transition_counter(c, s, a): assert a.state.transition_counter > 1 -@pytest.mark.slow @gen_cluster(client=True) async def test_transition_counter_max_scheduler(c, s, a, b): # This is set by @gen_cluster; it's False in production assert s.transition_counter_max > 0 s.transition_counter_max = 1 with captured_logger("distributed.scheduler") as logger: - with pytest.raises(CancelledError): + with pytest.raises(AssertionError): await c.submit(inc, 2) - assert s.transition_counter > 1 + assert s.transition_counter == 1 with pytest.raises(AssertionError): s.validate_state() assert "transition_counter_max" in logger.getvalue() @@ -4959,8 +4957,8 @@ def _match(event): _, msg = event return ( isinstance(msg, dict) - and msg.get("action", None) == "update_graph" - and msg["key-collisions"] > 0 + and msg.get("action", None) == "update-graph" 
+ and msg["metrics"]["key_collisions"] > 0 ) def handler(ev): @@ -4968,7 +4966,7 @@ def handler(ev): nonlocal seen seen = True - c.subscribe_topic("all", handler) + c.subscribe_topic("scheduler", handler) x1 = c.submit(inc, 1, key="x1") y_old = c.submit(inc, x1, key="y") diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 1182397d99..8fbfce7cae 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -36,7 +36,6 @@ RateLimiterFilter, TimeoutError, TupleComparable, - _maybe_complex, ensure_ip, ensure_memoryview, format_dashboard_link, @@ -234,15 +233,6 @@ def c(x): assert type(tb).__name__ == "traceback" -def test_maybe_complex(): - assert not _maybe_complex(1) - assert not _maybe_complex("x") - assert _maybe_complex((inc, 1)) - assert _maybe_complex([(inc, 1)]) - assert _maybe_complex([(inc, 1)]) - assert _maybe_complex({"x": (inc, 1)}) - - def test_read_block(): delimiter = b"\n" data = delimiter.join([b"123", b"456", b"789"]) diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index 94b1c33f3a..dbfe6c8a5b 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -6,6 +6,7 @@ import pytest +from dask._task_spec import TaskRef from dask.optimization import SubgraphCallable from distributed import wait @@ -13,8 +14,6 @@ from distributed.config import get_loop_factory from distributed.core import ConnectionPool, Status from distributed.utils_comm import ( - DoNotUnpack, - WrappedKey, gather_from_workers, pack_data, retry, @@ -232,53 +231,33 @@ async def f(): def test_unpack_remotedata(): - def assert_eq(keys1: set[WrappedKey], keys2: set[WrappedKey]) -> None: + def assert_eq(keys1: set[TaskRef], keys2: set[TaskRef]) -> None: if len(keys1) != len(keys2): assert False if not keys1: assert True - if not all(isinstance(k, WrappedKey) for k in keys1 & keys2): + if not all(isinstance(k, TaskRef) for k in keys1 & keys2): assert False assert sorted([k.key for k in keys1]) == sorted([k.key for k in keys2]) assert unpack_remotedata(1) == (1, set()) assert unpack_remotedata(()) == ((), set()) - res, keys = unpack_remotedata(WrappedKey("mykey")) + res, keys = unpack_remotedata(TaskRef("mykey")) assert res == "mykey" - assert_eq(keys, {WrappedKey("mykey")}) + assert_eq(keys, {TaskRef("mykey")}) # Check unpack of SC that contains a wrapped key - sc = SubgraphCallable({"key": (WrappedKey("data"),)}, outkey="key", inkeys=["arg1"]) + sc = SubgraphCallable({"key": (TaskRef("data"),)}, outkey="key", inkeys=["arg1"]) dsk = (sc, "arg1") res, keys = unpack_remotedata(dsk) assert res[0] != sc # Notice, the first item (the SC) has been changed assert res[1:] == ("arg1", "data") - assert_eq(keys, {WrappedKey("data")}) + assert_eq(keys, {TaskRef("data")}) # Check unpack of SC when it takes a wrapped key as argument - sc = SubgraphCallable({"key": ("arg1",)}, outkey="key", inkeys=[WrappedKey("arg1")]) + sc = SubgraphCallable({"key": ("arg1",)}, outkey="key", inkeys=[TaskRef("arg1")]) dsk = (sc, "arg1") res, keys = unpack_remotedata(dsk) assert res == (sc, "arg1") # Notice, the first item (the SC) has NOT been changed assert_eq(keys, set()) - - -def test_unpack_remotedata_custom_tuple(): - # We don't want to recurse into custom tuples. This is used as a sentinel to - # avoid recursion for performance reasons if we know that there are no - # nested futures. This test case is not how this feature should be used in - # practice. 
- - akey = WrappedKey("a") - - ordinary_tuple = (1, 2, akey) - dont_recurse = DoNotUnpack(ordinary_tuple) - - res, keys = unpack_remotedata(ordinary_tuple) - assert res is not ordinary_tuple - assert any(left != right for left, right in zip(ordinary_tuple, res)) - assert keys == {akey} - res, keys = unpack_remotedata(dont_recurse) - assert not keys - assert res is dont_recurse diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 1772ca3a94..1fe5ec58a0 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -423,7 +423,7 @@ def test_computetask_dummy(): nbytes={}, priority=(0,), duration=1.0, - run_spec=ComputeTaskEvent.dummy_runspec(), + run_spec=ComputeTaskEvent.dummy_runspec("x"), resource_restrictions={}, actor=False, annotations={}, diff --git a/distributed/utils.py b/distributed/utils.py index e360667bc9..7a127621eb 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -67,7 +67,6 @@ from tornado.ioloop import IOLoop import dask -from dask import istask from dask.utils import ensure_bytes as _ensure_bytes from dask.utils import key_split from dask.utils import parse_timedelta as _parse_timedelta @@ -922,6 +921,7 @@ def get_traceback(): os.path.join("distributed", "scheduler"), os.path.join("tornado", "gen.py"), os.path.join("concurrent", "futures"), + os.path.join("dask", "_task_spec"), ] while exc_traceback and any( b in exc_traceback.tb_frame.f_code.co_filename for b in bad @@ -941,17 +941,6 @@ def truncate_exception(e, n=10000): return e -def _maybe_complex(task): - """Possibly contains a nested task""" - return ( - istask(task) - or type(task) is list - and any(map(_maybe_complex, task)) - or type(task) is dict - and any(map(_maybe_complex, task.values())) - ) - - def seek_delimiter(file, delimiter, blocksize): """Seek current file to next byte after a delimiter bytestring diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index d37a808363..b8ca49b7fc 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -12,6 +12,7 @@ from tlz import drop, groupby, merge import dask.config +from dask._task_spec import TaskRef from dask.optimization import SubgraphCallable from dask.typing import Key from dask.utils import is_namedtuple_instance, parse_timedelta @@ -22,6 +23,10 @@ logger = logging.getLogger(__name__) +# Backwards compat +WrappedKey = TaskRef + + async def gather_from_workers( who_has: Mapping[Key, Collection[str]], rpc: ConnectionPool, @@ -130,24 +135,6 @@ async def gather_from_workers( return data, [], failed_keys, list(missing_workers) -class WrappedKey: - """Interface for a key in a dask graph. - - Subclasses must have .key attribute that refers to a key in a dask graph. - - Sometimes we want to associate metadata to keys in a dask graph. For - example we might know that that key lives on a particular machine or can - only be accessed in a certain way. Schedulers may have particular needs - that can only be addressed by additional metadata. 
- """ - - def __init__(self, key): - self.key = key - - def __repr__(self): - return f"{type(self).__name__}('{self.key}')" - - _round_robin_counter = [0] @@ -202,7 +189,7 @@ def _namedtuple_packing(o: Any, handler: Callable[..., Any]) -> Any: def _unpack_remotedata_inner( - o: Any, byte_keys: bool, found_futures: set[WrappedKey] + o: Any, byte_keys: bool, found_futures: set[TaskRef] ) -> Any: """Inner implementation of `unpack_remotedata` that adds found wrapped keys to `found_futures`""" @@ -212,7 +199,7 @@ def _unpack_remotedata_inner( return o if type(o[0]) is SubgraphCallable: # Unpack futures within the arguments of the subgraph callable - futures: set[WrappedKey] = set() + futures: set[TaskRef] = set() args = tuple(_unpack_remotedata_inner(i, byte_keys, futures) for i in o[1:]) found_futures.update(futures) @@ -261,7 +248,7 @@ def _unpack_remotedata_inner( } else: return o - elif issubclass(typ, WrappedKey): # TODO use type is Future + elif issubclass(typ, TaskRef): # TODO use type is Future k = o.key found_futures.add(o) return k @@ -269,41 +256,34 @@ def _unpack_remotedata_inner( return o -class DoNotUnpack(tuple): - """A tuple sublass to indicate that we should not unpack its contents - - See also unpack_remotedata - """ - - def unpack_remotedata(o: Any, byte_keys: bool = False) -> tuple[Any, set]: - """Unpack WrappedKey objects from collection + """Unpack TaskRef objects from collection - Returns original collection and set of all found WrappedKey objects + Returns original collection and set of all found TaskRef objects Examples -------- - >>> rd = WrappedKey('mykey') + >>> rd = TaskRef('mykey') >>> unpack_remotedata(1) (1, set()) >>> unpack_remotedata(()) ((), set()) >>> unpack_remotedata(rd) - ('mykey', {WrappedKey('mykey')}) + ('mykey', {TaskRef('mykey')}) >>> unpack_remotedata([1, rd]) - ([1, 'mykey'], {WrappedKey('mykey')}) + ([1, 'mykey'], {TaskRef('mykey')}) >>> unpack_remotedata({1: rd}) - ({1: 'mykey'}, {WrappedKey('mykey')}) + ({1: 'mykey'}, {TaskRef('mykey')}) >>> unpack_remotedata({1: [rd]}) - ({1: ['mykey']}, {WrappedKey('mykey')}) + ({1: ['mykey']}, {TaskRef('mykey')}) Use the ``byte_keys=True`` keyword to force string keys - >>> rd = WrappedKey(('x', 1)) + >>> rd = TaskRef(('x', 1)) >>> unpack_remotedata(rd, byte_keys=True) - ("('x', 1)", {WrappedKey('('x', 1)')}) + ("('x', 1)", {TaskRef('('x', 1)')}) """ - found_futures: set[WrappedKey] = set() + found_futures: set[TaskRef] = set() return _unpack_remotedata_inner(o, byte_keys, found_futures), found_futures diff --git a/distributed/worker.py b/distributed/worker.py index 7e3fecb9b2..1dcfa9ea8d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -45,11 +45,10 @@ from tornado.ioloop import IOLoop import dask -from dask.core import istask +from dask._task_spec import GraphNode from dask.system import CPU_COUNT from dask.typing import Key from dask.utils import ( - apply, format_bytes, funcname, key_split, @@ -98,7 +97,6 @@ from distributed.threadpoolexecutor import secede as tpe_secede from distributed.utils import ( TimeoutError, - _maybe_complex, get_ip, has_arg, in_async_call, @@ -114,7 +112,7 @@ thread_state, wait_for, ) -from distributed.utils_comm import gather_from_workers, pack_data, retry_operation +from distributed.utils_comm import gather_from_workers, retry_operation from distributed.versions import get_versions from distributed.worker_memory import ( DeprecatedMemoryManagerAttribute, @@ -158,7 +156,6 @@ # Circular imports from distributed.client import Client from distributed.nanny import 
Nanny - from distributed.scheduler import T_runspec P = ParamSpec("P") T = TypeVar("T") @@ -183,6 +180,27 @@ } +class RunTaskSuccess(OKMessage): + op: Literal["task-finished"] + result: object + nbytes: int + type: type + start: float + stop: float + thread: int + + +class RunTaskFailure(ErrorMessage): + op: Literal["task-erred"] + result: object + nbytes: int + type: type + start: float + stop: float + thread: int + actual_exception: BaseException | Exception + + class GetDataBusy(TypedDict): status: Literal["busy"] @@ -2177,7 +2195,7 @@ async def actor_execute( elif separate_thread: result = await self.loop.run_in_executor( self.executors["actor"], - apply_function_actor, + _run_actor, func, args, kwargs, @@ -2227,8 +2245,23 @@ async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent: assert ts.state in ("executing", "cancelled", "resumed"), ts assert ts.run_spec is not None - function, args, kwargs = ts.run_spec - args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs) + start = time() + data: dict[Key, Any] = {} + for dep in ts.dependencies: + dkey = dep.key + actors = self.state.actors + if actors and dkey in actors: + from distributed.actor import Actor # TODO: create local actor + + data[dkey] = Actor(type(actors[dkey]), self.address, dkey, self) + else: + data[dkey] = self.data[dkey] + + stop = time() + if stop - start > 0.005: + ts.startstops.append( + {"action": "disk-read", "start": start, "stop": stop} + ) assert ts.annotations is not None executor = ts.annotations.get("executor", "default") @@ -2249,16 +2282,16 @@ async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent: else contextlib.nullcontext() ) span_ctx.__enter__() - + run_spec = ts.run_spec try: ts.start_time = time() - if iscoroutinefunction(function): + + if ts.run_spec.is_coro: token = _worker_cvar.set(self) try: - result = await apply_function_async( - function, - args2, - kwargs2, + result = await _run_task_async( + ts.run_spec, + data, self.scheduler_delay, ) finally: @@ -2268,15 +2301,14 @@ async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent: # e.g. thread synchronization overhead only, since thread-noncpu and # thread-cpu inside the thread detract from it. However, it may # become substantial in case of misalignment between the size of the - # thread pool and the number of running tasks in the worker state + # thread pool and the number of running tasks in the worker stater # machine (e.g. 
@@ -2249,16 +2282,16 @@ async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent:
                else contextlib.nullcontext()
            )
            span_ctx.__enter__()
-
+            run_spec = ts.run_spec
            try:
                ts.start_time = time()
-                if iscoroutinefunction(function):
+
+                if ts.run_spec.is_coro:
                    token = _worker_cvar.set(self)
                    try:
-                        result = await apply_function_async(
-                            function,
-                            args2,
-                            kwargs2,
+                        result = await _run_task_async(
+                            ts.run_spec,
+                            data,
                            self.scheduler_delay,
                        )
                    finally:
@@ -2268,15 +2301,14 @@ async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent:
                    # e.g. thread synchronization overhead only, since thread-noncpu and
                    # thread-cpu inside the thread detract from it. However, it may
                    # become substantial in case of misalignment between the size of the
                    # thread pool and the number of running tasks in the worker state
                    # machine (e.g. https://github.com/dask/distributed/issues/5882)
                    with context_meter.meter("executor"):
                        result = await run_in_executor_with_context(
                            e,
-                            apply_function,
-                            function,
-                            args2,
-                            kwargs2,
+                            _run_task,
+                            ts.run_spec,
+                            data,
                            self.execution_state,
                            key,
                            self.active_threads,
@@ -2290,10 +2322,9 @@ async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent:
                    with context_meter.meter("executor"):
                        result = await self.loop.run_in_executor(
                            e,
-                            apply_function_simple,
-                            function,
-                            args2,
-                            kwargs2,
+                            _run_task_simple,
+                            ts.run_spec,
+                            data,
                            self.scheduler_delay,
                        )
            finally:
@@ -2318,13 +2349,13 @@ async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent:
                    stimulus_id=f"task-finished-{time()}",
                )

-            task_exc = result["actual-exception"]
+            task_exc = result["actual_exception"]
            if isinstance(task_exc, Reschedule):
                return RescheduleEvent(key=ts.key, stimulus_id=f"reschedule-{time()}")
            if (
                self.status == Status.closing
                and isinstance(task_exc, asyncio.CancelledError)
-                and iscoroutinefunction(function)
+                and run_spec.is_coro
            ):
                # `Worker.cancel` will cause async user tasks to raise `CancelledError`.
                # Since we cancelled those tasks, we shouldn't treat them as failures.
@@ -2342,16 +2373,12 @@ async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent:
                "Compute Failed\n"
                "Key: %s\n"
                "State: %s\n"
-                "Function: %s\n"
-                "args: %s\n"
-                "kwargs: %s\n"
+                "Task: %s\n"
                "Exception: %r\n"
                "Traceback: %r\n",
                key,
                ts.state,
-                str(funcname(function))[:1000],
-                convert_args_to_str(args2, max_len=1000),
-                convert_kwargs_to_str(kwargs2, max_len=1000),
+                repr(run_spec)[:1000],
                result["exception_text"],
                result["traceback_text"],
            )
@@ -2385,27 +2412,6 @@ async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent:
                stimulus_id=f"execute-unknown-error-{time()}",
            )

-    def _prepare_args_for_execution(
-        self, ts: TaskState, args: tuple, kwargs: dict[str, Any]
-    ) -> tuple[tuple[object, ...], dict[str, object]]:
-        start = time()
-        data = {}
-        for dep in ts.dependencies:
-            k = dep.key
-            try:
-                data[k] = self.data[k]
-            except KeyError:
-                from distributed.actor import Actor  # TODO: create local actor
-
-                data[k] = Actor(type(self.state.actors[k]), self.address, k, self)
-        args2 = pack_data(args, data, key_types=(bytes, str, tuple))
-        kwargs2 = pack_data(kwargs, data, key_types=(bytes, str, tuple))
-        stop = time()
-        if stop - start > 0.005:
-            ts.startstops.append({"action": "disk-read", "start": start, "stop": stop})
-
-        return args2, kwargs2
-
    ##################
    # Administrative #
    ##################
@@ -2891,34 +2897,6 @@ async def get_data_from_worker(
        rpc.reuse(worker, comm)


-def _normalize_task(task: Any) -> T_runspec:
-    if istask(task):
-        if task[0] is apply and not any(map(_maybe_complex, task[2:])):
-            return task[1], task[2], task[3] if len(task) == 4 else {}
-        elif not any(map(_maybe_complex, task[1:])):
-            return task[0], task[1:], {}
-
-    return execute_task, (task,), {}
-
-
-def execute_task(task):
-    """Evaluate a nested task
-
-    >>> inc = lambda x: x + 1
-    >>> execute_task((inc, 1))
-    2
-    >>> execute_task((sum, [1, 2, (inc, 3)]))
-    7
-    """
-    if istask(task):
-        func, args = task[0], task[1:]
-        return func(*map(execute_task, args))
-    elif isinstance(task, list):
-        return list(map(execute_task, task))
-    else:
-        return task
-
-
 cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100)

 _cache_lock = threading.Lock()
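Note: `_normalize_task` and `execute_task` existed only to coerce arbitrary task tuples into the old `(function, args, kwargs)` runspec and to evaluate nested tuples recursively. With `run_spec` now being a `GraphNode`, that conversion happens when the graph is built, so both helpers can go. For reference, the behaviour of the removed doctest expressed with the new spec; a sketch assuming `Task.__call__` accepts a key-to-value mapping and that `TaskRef` may appear inside container arguments:

    from dask._task_spec import Task, TaskRef

    def inc(x):
        return x + 1

    # old: execute_task((inc, 1)) == 2
    assert Task("k", inc, 1)({}) == 2

    # old nested form (sum, [1, 2, (inc, 3)]) becomes two tasks linked by a TaskRef
    part = Task("part", inc, 3)
    total = Task("total", sum, [1, 2, TaskRef("part")])
    assert total({"part": part({})}) == 7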
@@ -2939,16 +2917,15 @@ def dumps_function(func) -> bytes:
     return result


-def apply_function(
-    function,
-    args,
-    kwargs,
-    execution_state,
-    key,
-    active_threads,
-    active_threads_lock,
-    time_delay,
-):
+def _run_task(
+    task: GraphNode,
+    data: dict,
+    execution_state: dict,
+    key: Key,
+    active_threads: dict,
+    active_threads_lock: threading.Lock,
+    time_delay: float,
+) -> RunTaskSuccess | RunTaskFailure:
     """Run a function, collect information

     Returns
@@ -2965,7 +2942,7 @@ def apply_function(
     ):
         token = _worker_cvar.set(execution_state["worker"])
         try:
-            msg = apply_function_simple(function, args, kwargs, time_delay)
+            msg = _run_task_simple(task, data, time_delay)
         finally:
             _worker_cvar.reset(token)

@@ -2974,12 +2951,11 @@ def apply_function(
     return msg


-def apply_function_simple(
-    function,
-    args,
-    kwargs,
-    time_delay,
-):
+def _run_task_simple(
+    task: GraphNode,
+    data: dict,
+    time_delay: float,
+) -> RunTaskSuccess | RunTaskFailure:
     """Run a function, collect information

     Returns
@@ -3002,7 +2978,7 @@ def apply_function_simple(
         context_meter.meter("thread-cpu", func=thread_time),
     ):
         try:
-            result = function(*args, **kwargs)
+            result = task(data)
         except (SystemExit, KeyboardInterrupt):
             # Special-case these, just like asyncio does all over the place. They will
             # pass through `fail_hard` and `_handle_stimulus_from_task`, and eventually
@@ -3015,11 +2991,11 @@ def apply_function_simple(
             # Users _shouldn't_ use `BaseException`s, but if they do, we can assume they
             # aren't a reason to shut down the whole system (since we allow the
             # system-shutting-down `SystemExit` and `KeyboardInterrupt` to pass through)
-            msg = error_message(e)
+            msg: RunTaskFailure = error_message(e)  # type: ignore
             msg["op"] = "task-erred"
-            msg["actual-exception"] = e
+            msg["actual_exception"] = e
         else:
-            msg = {
+            msg: RunTaskSuccess = {  # type: ignore
                 "op": "task-finished",
                 "status": "OK",
                 "result": result,
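Note: the three thread-pool entry points keep their structure; they only switch from `function(*args, **kwargs)` to `task(data)` and now return the typed `RunTaskSuccess`/`RunTaskFailure` messages (with `actual_exception` instead of `actual-exception`). The async variant below is selected in `execute()` via `run_spec.is_coro`. A small sketch of that dispatch, assuming `is_coro` mirrors `iscoroutinefunction` of the wrapped callable and that calling an async task returns an awaitable:

    import asyncio
    from dask._task_spec import Task

    async def ainc(x):
        return x + 1

    def inc(x):
        return x + 1

    assert Task("a", ainc, 1).is_coro       # dispatched to _run_task_async
    assert not Task("b", inc, 1).is_coro    # dispatched to the thread pool paths
    assert asyncio.run(Task("a", ainc, 1)({})) == 2  # awaited on the worker's loop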
@@ -3033,12 +3009,11 @@ def apply_function_simple(
     return msg


-async def apply_function_async(
-    function,
-    args,
-    kwargs,
-    time_delay,
-):
+async def _run_task_async(
+    task: GraphNode,
+    data: dict,
+    time_delay: float,
+) -> RunTaskSuccess | RunTaskFailure:
     """Run a function, collect information

     Returns
@@ -3047,7 +3022,7 @@ async def apply_function_async(
     """
     with context_meter.meter("thread-noncpu", func=time) as m:
         try:
-            result = await function(*args, **kwargs)
+            result = await task(data)
         except (SystemExit, KeyboardInterrupt):
             # Special-case these, just like asyncio does all over the place. They will
             # pass through `fail_hard` and `_handle_stimulus_from_task`, and eventually
@@ -3062,11 +3037,11 @@ async def apply_function_async(
             # Users _shouldn't_ use `BaseException`s, but if they do, we can assume they
             # aren't a reason to shut down the whole system (since we allow the
             # system-shutting-down `SystemExit` and `KeyboardInterrupt` to pass through)
-            msg = error_message(e)
+            msg: RunTaskFailure = error_message(e)  # type: ignore
             msg["op"] = "task-erred"
-            msg["actual-exception"] = e
+            msg["actual_exception"] = e
         else:
-            msg = {
+            msg: RunTaskSuccess = {  # type: ignore
                 "op": "task-finished",
                 "status": "OK",
                 "result": result,
@@ -3080,9 +3055,15 @@ async def apply_function_async(
     return msg


-def apply_function_actor(
-    function, args, kwargs, execution_state, key, active_threads, active_threads_lock
-):
+def _run_actor(
+    func: Callable,
+    args: tuple,
+    kwargs: dict,
+    execution_state: dict,
+    key: Key,
+    active_threads: dict,
+    active_threads_lock: threading.Lock,
+) -> Any:
     """Run a function, collect information

     Returns
@@ -3102,7 +3083,7 @@ def apply_function_actor(
     ):
         token = _worker_cvar.set(execution_state["worker"])
         try:
-            result = function(*args, **kwargs)
+            result = func(*args, **kwargs)
         finally:
             _worker_cvar.reset(token)

diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index caf6ed2ac7..3507fb9379 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -30,6 +30,7 @@
 from tlz import peekn

 import dask
+from dask._task_spec import Task
 from dask.typing import Key
 from dask.utils import key_split, parse_bytes, typename

@@ -39,7 +40,7 @@
 from distributed.core import ErrorMessage, error_message
 from distributed.metrics import DelayedMetricsLedger, monotonic, time
 from distributed.protocol import pickle
-from distributed.protocol.serialize import Serialize, ToPickle
+from distributed.protocol.serialize import Serialize
 from distributed.sizeof import safe_sizeof as sizeof
 from distributed.utils import recursive_to_dict

@@ -748,10 +749,6 @@ def __post_init__(self) -> None:
         # Fixes after msgpack decode
         if isinstance(self.priority, list):  # type: ignore[unreachable]
             self.priority = tuple(self.priority)  # type: ignore[unreachable]
-        if isinstance(self.run_spec, ToPickle):
-            # FIXME Sometimes the protocol is not unpacking this
-            # E.g. distributed/tests/test_client.py::test_async_with
-            self.run_spec = self.run_spec.data  # type: ignore[unreachable]

     def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
         return StateMachineEvent._to_dict(self._clean(), exclude=exclude)
@@ -774,8 +771,8 @@ def _f(cls) -> None:
         return  # pragma: nocover

     @classmethod
-    def dummy_runspec(cls) -> tuple[Callable, tuple, dict]:
-        return (cls._f, (), {})
+    def dummy_runspec(cls, key: Key) -> Task:
+        return Task(key, cls._f)

     @staticmethod
     def dummy(
@@ -801,7 +798,7 @@ def dummy(
             nbytes=nbytes or {k: 1 for k in who_has or ()},
             priority=priority,
             duration=duration,
-            run_spec=ComputeTaskEvent.dummy_runspec(),
+            run_spec=ComputeTaskEvent.dummy_runspec(key),
             resource_restrictions=resource_restrictions or {},
             actor=actor,
             annotations=annotations or {},
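Note: with the `ToPickle` unwrapping workaround gone, `run_spec` arrives as a `Task` directly, and the test helper now has to bind the dummy runspec to the event's key, since a `Task` carries its own key (unlike the old `(func, args, kwargs)` tuple). A sketch of the updated helper in use, assuming `ComputeTaskEvent.dummy` keeps the key-plus-`stimulus_id` signature used elsewhere in the test suite:

    from distributed.worker_state_machine import ComputeTaskEvent

    ev = ComputeTaskEvent.dummy("x", stimulus_id="test")
    assert ev.run_spec is not None
    assert ev.run_spec.key == "x"   # the dummy Task is bound to the event's key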