Reduce P2P transfer task overhead #8912

Merged: 9 commits, Oct 30, 2024
60 changes: 60 additions & 0 deletions distributed/shuffle/_core.py
@@ -25,7 +25,9 @@
from tornado.ioloop import IOLoop

import dask.config
from dask._task_spec import Task, _inline_recursively
from dask.core import flatten
from dask.sizeof import sizeof
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta

@@ -575,3 +577,61 @@ def p2p_barrier(id: ShuffleId, run_ids: list[int]) -> int:
raise
except Exception as e:
raise RuntimeError(f"P2P {id} failed during barrier phase") from e


class P2PBarrierTask(Task):
spec: ShuffleSpec

__slots__ = tuple(__annotations__)

def __init__(
self,
key: Any,
func: Callable[..., Any],
/,
*args: Any,
spec: ShuffleSpec,
**kwargs: Any,
):
self.spec = spec
super().__init__(key, func, *args, **kwargs)

def copy(self) -> P2PBarrierTask:
self.unpack()
assert self.func is not None
return P2PBarrierTask(
self.key, self.func, *self.args, spec=self.spec, **self.kwargs
)

def __sizeof__(self) -> int:
return super().__sizeof__() + sizeof(self.spec)

def __repr__(self) -> str:
return f"P2PBarrierTask({self.key!r})"

def inline(self, dsk: dict[Key, Any]) -> P2PBarrierTask:
self.unpack()
new_args = _inline_recursively(self.args, dsk)
new_kwargs = _inline_recursively(self.kwargs, dsk)
assert self.func is not None
return P2PBarrierTask(
self.key, self.func, *new_args, spec=self.spec, **new_kwargs
)

def __getstate__(self) -> dict[str, Any]:
state = super().__getstate__()
state["spec"] = self.spec
return state

def __setstate__(self, state: dict[str, Any]) -> None:
super().__setstate__(state)
self.spec = state["spec"]

def __eq__(self, value: object) -> bool:
if not isinstance(value, P2PBarrierTask):
return False
if not super().__eq__(value):
return False
if self.spec != value.spec:
return False
return True
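
A minimal sketch of what the new class buys (not part of the diff): the shuffle spec rides on the single barrier task rather than on every transfer task, and it survives the pickle round trips that `Task` objects go through. The `ArrayRechunkSpec` construction mirrors the one used in `_rechunk.py` below; the concrete values are illustrative only.

```python
import pickle

from distributed.shuffle._core import P2PBarrierTask, ShuffleId, p2p_barrier
from distributed.shuffle._rechunk import ArrayRechunkSpec

# Illustrative spec: rechunk one axis from a single chunk of 4 into 2 + 2.
spec = ArrayRechunkSpec(
    id=ShuffleId("example-token"),
    new=((2, 2),),
    old=((4,),),
    disk=False,
)
barrier = P2PBarrierTask("example-barrier", p2p_barrier, spec.id, [], spec=spec)

# __getstate__/__setstate__ carry the spec across serialization, so
# scheduler-side code can read it straight off the barrier task.
restored = pickle.loads(pickle.dumps(barrier))
assert restored.spec == spec
assert restored == barrier  # __eq__ compares the spec as well
```
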
117 changes: 74 additions & 43 deletions distributed/shuffle/_merge.py
@@ -1,30 +1,37 @@
# mypy: ignore-errors
from __future__ import annotations

from collections.abc import Iterable, Sequence
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any

import dask
from dask._task_spec import GraphNode, Task, TaskRef
from dask.base import is_dask_collection
from dask.highlevelgraph import HighLevelGraph
from dask.layers import Layer
from dask.tokenize import tokenize
from dask.typing import Key

from distributed.shuffle._arrow import check_minimal_arrow_version
from distributed.shuffle._core import (
P2PBarrierTask,
ShuffleId,
barrier_key,
get_worker_plugin,
p2p_barrier,
)
from distributed.shuffle._shuffle import shuffle_transfer
from distributed.shuffle._shuffle import DataFrameShuffleSpec, shuffle_transfer

if TYPE_CHECKING:
import pandas as pd
from pandas._typing import IndexLabel, MergeHow, Suffixes

# TODO import from typing (requires Python >=3.10)
from typing_extensions import TypeAlias

from dask.dataframe.core import _Frame

_T_LowLevelGraph: TypeAlias = dict[Key, GraphNode]

_HASH_COLUMN_NAME = "__hash_partition"

@@ -148,21 +155,11 @@ def merge_transfer(
input: pd.DataFrame,
id: ShuffleId,
input_partition: int,
npartitions: int,
meta: pd.DataFrame,
parts_out: set[int],
disk: bool,
):
return shuffle_transfer(
input=input,
id=id,
input_partition=input_partition,
npartitions=npartitions,
column=_HASH_COLUMN_NAME,
meta=meta,
parts_out=parts_out,
disk=disk,
drop_column=True,
)


@@ -208,7 +205,7 @@ class HashJoinP2PLayer(Layer):
suffixes: Suffixes
indicator: bool
meta_output: pd.DataFrame
parts_out: Sequence[int]
parts_out: set[int]

name_input_left: str
meta_input_left: pd.DataFrame
@@ -241,7 +238,7 @@ def __init__(
how: MergeHow = "inner",
suffixes: Suffixes = ("_x", "_y"),
indicator: bool = False,
parts_out: Sequence | None = None,
parts_out: Iterable[int] | None = None,
annotations: dict | None = None,
) -> None:
check_minimal_arrow_version()
@@ -257,7 +254,10 @@ def __init__(
self.suffixes = suffixes
self.indicator = indicator
self.meta_output = meta_output
self.parts_out = parts_out or list(range(npartitions))
if parts_out:
self.parts_out = set(parts_out)
else:
self.parts_out = set(range(npartitions))
self.n_partitions_left = n_partitions_left
self.n_partitions_right = n_partitions_right
self.left_index = left_index
@@ -325,7 +325,7 @@ def _dict(self):
self._cached_dict = dsk
return self._cached_dict

def _cull(self, parts_out: Sequence[str]):
def _cull(self, parts_out: Iterable[int]):
return HashJoinP2PLayer(
name=self.name,
name_input_left=self.name_input_left,
@@ -365,7 +365,7 @@ def cull(self, keys: Iterable[str], all_keys: Any) -> tuple[HashJoinP2PLayer, di
else:
return self, culled_deps

def _construct_graph(self) -> dict[tuple | str, tuple]:
def _construct_graph(self) -> _T_LowLevelGraph:
token_left = tokenize(
# Include self.name to ensure that shuffle IDs are unique for individual
# merge operations. Reusing shuffles between merges is dangerous because of
@@ -375,6 +375,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
self.left_on,
self.left_index,
)
shuffle_id_left = ShuffleId(token_left)
token_right = tokenize(
# Include self.name to ensure that shuffle IDs are unique for individual
# merge operations. Reusing shuffles between merges is dangerous because of
@@ -384,50 +385,79 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
self.right_on,
self.right_index,
)
dsk: dict[tuple | str, tuple] = {}
shuffle_id_right = ShuffleId(token_right)
dsk: _T_LowLevelGraph = {}
name_left = "hash-join-transfer-" + token_left
name_right = "hash-join-transfer-" + token_right
transfer_keys_left = list()
transfer_keys_right = list()
for i in range(self.n_partitions_left):
transfer_keys_left.append((name_left, i))
dsk[(name_left, i)] = (
t = Task(
(name_left, i),
merge_transfer,
(self.name_input_left, i),
token_left,
TaskRef((self.name_input_left, i)),
shuffle_id_left,
i,
self.npartitions,
self.meta_input_left,
self.parts_out,
self.disk,
)
dsk[t.key] = t
transfer_keys_left.append(t.ref())

transfer_keys_right = list()
for i in range(self.n_partitions_right):
transfer_keys_right.append((name_right, i))
dsk[(name_right, i)] = (
t = Task(
(name_right, i),
merge_transfer,
(self.name_input_right, i),
token_right,
TaskRef((self.name_input_right, i)),
shuffle_id_right,
i,
self.npartitions,
self.meta_input_right,
self.parts_out,
self.disk,
)

_barrier_key_left = barrier_key(ShuffleId(token_left))
_barrier_key_right = barrier_key(ShuffleId(token_right))
dsk[_barrier_key_left] = (p2p_barrier, token_left, transfer_keys_left)
dsk[_barrier_key_right] = (p2p_barrier, token_right, transfer_keys_right)
dsk[t.key] = t
transfer_keys_right.append(t.ref())

_barrier_key_left = barrier_key(shuffle_id_left)
barrier_left = P2PBarrierTask(
_barrier_key_left,
p2p_barrier,
token_left,
transfer_keys_left,
spec=DataFrameShuffleSpec(
id=shuffle_id_left,
npartitions=self.npartitions,
column=_HASH_COLUMN_NAME,
meta=self.meta_input_left,
parts_out=self.parts_out,
disk=self.disk,
drop_column=True,
),
)
dsk[barrier_left.key] = barrier_left
_barrier_key_right = barrier_key(shuffle_id_right)
barrier_right = P2PBarrierTask(
_barrier_key_right,
p2p_barrier,
token_right,
transfer_keys_right,
spec=DataFrameShuffleSpec(
id=shuffle_id_right,
npartitions=self.npartitions,
column=_HASH_COLUMN_NAME,
meta=self.meta_input_right,
parts_out=self.parts_out,
disk=self.disk,
drop_column=True,
),
)
dsk[barrier_right.key] = barrier_right

name = self.name
for part_out in self.parts_out:
dsk[(name, part_out)] = (
t = Task(
(name, part_out),
merge_unpack,
token_left,
token_right,
part_out,
_barrier_key_left,
_barrier_key_right,
barrier_left.ref(),
barrier_right.ref(),
self.how,
self.left_on,
self.right_on,
@@ -437,4 +467,5 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
self.right_index,
self.indicator,
)
dsk[t.key] = t
return dsk
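
As context for the rewrite above, a self-contained sketch of the `Task`/`TaskRef` convention the new `_construct_graph` follows (names other than `Task` and `TaskRef` are made up): each node is stored under its own key, and downstream tasks reference their inputs via `.ref()` instead of embedding bare key tuples in an opaque tuple spec.

```python
from dask._task_spec import Task, TaskRef

def combine(left: str, right: str) -> str:
    # Stand-in for merge_unpack.
    return left + right

dsk = {}

# Upstream tasks are stored under their own keys...
t_left = Task(("transfer", 0), str.upper, "a")
dsk[t_left.key] = t_left
t_right = Task(("transfer", 1), str.upper, "b")
dsk[t_right.key] = t_right

# ...and the downstream task points at them with TaskRefs, making its
# dependencies explicit instead of inferred from tuple contents.
t_out = Task(("merged", 0), combine, t_left.ref(), t_right.ref())
dsk[t_out.key] = t_out

assert t_out.dependencies == {("transfer", 0), ("transfer", 1)}
```
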
25 changes: 13 additions & 12 deletions distributed/shuffle/_rechunk.py
@@ -132,17 +132,18 @@
from distributed.metrics import context_meter
from distributed.shuffle._core import (
NDIndex,
P2PBarrierTask,
ShuffleId,
ShuffleRun,
ShuffleSpec,
barrier_key,
get_worker_plugin,
handle_transfer_errors,
handle_unpack_errors,
p2p_barrier,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._pickle import unpickle_bytestream
from distributed.shuffle._shuffle import barrier_key
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.sizeof import sizeof

@@ -164,15 +165,12 @@ def rechunk_transfer(
input: np.ndarray,
id: ShuffleId,
input_chunk: NDIndex,
new: ChunkedAxes,
old: ChunkedAxes,
disk: bool,
) -> int:
with handle_transfer_errors(id):
return get_worker_plugin().add_partition(
input,
partition_id=input_chunk,
spec=ArrayRechunkSpec(id=id, new=new, old=old, disk=disk),
id=id,
)


@@ -815,16 +813,19 @@ def partial_rechunk(
key,
rechunk_transfer,
input_key,
partial_token,
ShuffleId(partial_token),
partial_index,
partial_new,
partial_old,
disk,
)
transfer_keys.append(t.ref())

dsk[_barrier_key] = barrier = Task(
_barrier_key, p2p_barrier, partial_token, transfer_keys
dsk[_barrier_key] = barrier = P2PBarrierTask(
_barrier_key,
p2p_barrier,
partial_token,
transfer_keys,
spec=ArrayRechunkSpec(
id=ShuffleId(partial_token), new=partial_new, old=partial_old, disk=disk
),
)

new_partial_offset = tuple(axis.start for axis in ndpartial.new)
@@ -835,7 +836,7 @@ def partial_rechunk(
dsk[k] = Task(
k,
rechunk_unpack,
partial_token,
ShuffleId(partial_token),
partial_index,
barrier.ref(),
)
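
The repeated `ShuffleId(partial_token)` wrapping above is free at runtime: assuming `ShuffleId` is the `typing.NewType` over `str` defined in `distributed/shuffle/_core.py`, the cast only exists for static checkers, which can then reject raw strings where a shuffle id is expected.

```python
from typing import NewType

# Assumed to match the definition in distributed/shuffle/_core.py.
ShuffleId = NewType("ShuffleId", str)

def unpack_stub(id: ShuffleId) -> None:
    """Hypothetical stand-in for a function that requires a ShuffleId."""

partial_token = "abc123"
unpack_stub(ShuffleId(partial_token))  # OK: explicit, zero-cost cast
# unpack_stub(partial_token)           # rejected by mypy: str is not ShuffleId
```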