Reduce P2P transfer task overhead #8912

Merged
merged 9 commits on Oct 30, 2024
Changes from 1 commit
59 changes: 59 additions & 0 deletions distributed/shuffle/_core.py
@@ -25,7 +25,9 @@
from tornado.ioloop import IOLoop

import dask.config
from dask._task_spec import Task, _inline_recursively
from dask.core import flatten
from dask.sizeof import sizeof
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta

@@ -575,3 +577,60 @@ def p2p_barrier(id: ShuffleId, run_ids: list[int]) -> int:
raise
except Exception as e:
raise RuntimeError(f"P2P {id} failed during barrier phase") from e


class P2PBarrierTask(Task):
spec: ShuffleSpec

__slots__ = tuple(__annotations__)

def __init__(
self,
key: Any,
func: Callable[..., Any],
/,
*args: Any,
spec: ShuffleSpec,
**kwargs: Any,
):
self.spec = spec
super().__init__(key, func, *args, **kwargs)

def copy(self):
self.unpack()
return P2PBarrierTask(
self.key, self.func, *self.args, spec=self.spec, **self.kwargs
)

def __sizeof__(self) -> int:
return super().__sizeof__() + sizeof(self.spec)

def __repr__(self) -> str:
return f"P2PBarrierTask({self.key!r})"

def inline(self, dsk) -> P2PBarrierTask:
self.unpack()
new_args = _inline_recursively(self.args, dsk)
new_kwargs = _inline_recursively(self.kwargs, dsk)
assert self.func is not None
return P2PBarrierTask(
self.key, self.func, *new_args, spec=self.spec, **new_kwargs
)

def __getstate__(self):
state = super().__getstate__()
state["spec"] = self.spec
return state

def __setstate__(self, state):
super().__setstate__(state)
self.spec = state["spec"]

Member:
I'm OK with this solution. Just to be clear, though, when we talked the other day I was hoping to be able to implement a more robust version of this in the base class that doesn't require the child classes to override it. If that's not easily possible, that's fine.

Member Author:

Noted, but I'd like to isolate changes in the base class from this PR. I'm already juggling two repos here and would like to avoid adding a third into the mix.
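
The generic base-class approach hinted at above could look roughly like the sketch below. This is purely illustrative and not part of this PR or of dask's actual Task implementation; the SlottedTaskBase name, its attributes, and the _slot_names helper are made up for the example.

# Hypothetical sketch: a base class whose __getstate__/__setstate__ walk
# __slots__ across the MRO, so a subclass that only adds new slots (as
# P2PBarrierTask adds `spec`) would not need to override serialization.
class SlottedTaskBase:
    __slots__ = ("key", "func", "args", "kwargs")

    def __init__(self, key, func, *args, **kwargs):
        self.key = key
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def _slot_names(self):
        # Collect slot names declared anywhere in the class hierarchy,
        # preserving order and skipping duplicates.
        names = []
        for klass in type(self).__mro__:
            for name in getattr(klass, "__slots__", ()):
                if name not in names:
                    names.append(name)
        return names

    def __getstate__(self):
        # Every declared slot is serialized, including slots added by subclasses.
        return {name: getattr(self, name) for name in self._slot_names()}

    def __setstate__(self, state):
        for name, value in state.items():
            setattr(self, name, value)

Because the slots are discovered at runtime, a subclass declaring extra __slots__ (for example ("spec",)) would round-trip through pickle without defining __getstate__ or __setstate__ itself.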


def __eq__(self, value: object) -> bool:
if not isinstance(value, P2PBarrierTask):
return False
if not super().__eq__(value):
return False
if self.spec != value.spec:
return False
return True
25 changes: 13 additions & 12 deletions distributed/shuffle/_rechunk.py
@@ -132,17 +132,18 @@
from distributed.metrics import context_meter
from distributed.shuffle._core import (
NDIndex,
P2PBarrierTask,
ShuffleId,
ShuffleRun,
ShuffleSpec,
barrier_key,
get_worker_plugin,
handle_transfer_errors,
handle_unpack_errors,
p2p_barrier,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._pickle import unpickle_bytestream
from distributed.shuffle._shuffle import barrier_key
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.sizeof import sizeof

@@ -164,15 +165,12 @@ def rechunk_transfer(
input: np.ndarray,
id: ShuffleId,
input_chunk: NDIndex,
new: ChunkedAxes,
old: ChunkedAxes,
disk: bool,
) -> int:
with handle_transfer_errors(id):
return get_worker_plugin().add_partition(
input,
partition_id=input_chunk,
spec=ArrayRechunkSpec(id=id, new=new, old=old, disk=disk),
id=id,
)


@@ -815,16 +813,19 @@ def partial_rechunk(
key,
rechunk_transfer,
input_key,
partial_token,
ShuffleId(partial_token),
partial_index,
partial_new,
partial_old,
disk,
)
transfer_keys.append(t.ref())

dsk[_barrier_key] = barrier = Task(
_barrier_key, p2p_barrier, partial_token, transfer_keys
dsk[_barrier_key] = barrier = P2PBarrierTask(
_barrier_key,
p2p_barrier,
partial_token,
transfer_keys,
spec=ArrayRechunkSpec(
id=ShuffleId(partial_token), new=partial_new, old=partial_old, disk=disk
),
)

new_partial_offset = tuple(axis.start for axis in ndpartial.new)
@@ -835,7 +836,7 @@
dsk[k] = Task(
k,
rechunk_unpack,
partial_token,
ShuffleId(partial_token),
partial_index,
barrier.ref(),
)
23 changes: 15 additions & 8 deletions distributed/shuffle/_scheduler_plugin.py
@@ -14,6 +14,7 @@
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
@@ -147,41 +148,47 @@ def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
state.participating_workers.add(worker)
return state.run_spec

def _create(self, spec: ShuffleSpec, key: Key, worker: str) -> ShuffleRunSpec:
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec

def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(spec.id)
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[spec.id] = state
self._shuffles[spec.id].add(state)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
spec.id,
shuffle_id,
key,
worker,
)
return state.run_spec

def get_or_create(
self,
spec: ShuffleSpec,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
run_spec = self._get(spec.id, worker)
run_spec = self._get(shuffle_id, worker)
except P2PConsistencyError as e:
return error_message(e)
except KeyError:
try:
run_spec = self._create(spec, key, worker)
run_spec = self._create(shuffle_id, key, worker)
except P2PConsistencyError as e:
return error_message(e)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
42 changes: 19 additions & 23 deletions distributed/shuffle/_shuffle.py
@@ -41,6 +41,7 @@
)
from distributed.shuffle._core import (
NDIndex,
P2PBarrierTask,
ShuffleId,
ShuffleRun,
ShuffleSpec,
@@ -70,26 +71,12 @@ def shuffle_transfer(
input: pd.DataFrame,
id: ShuffleId,
input_partition: int,
npartitions: int,
column: str,
meta: pd.DataFrame,
parts_out: set[int],
disk: bool,
drop_column: bool,
) -> int:
with handle_transfer_errors(id):
return get_worker_plugin().add_partition(
input,
input_partition,
spec=DataFrameShuffleSpec(
id=id,
npartitions=npartitions,
column=column,
meta=meta,
parts_out=parts_out,
disk=disk,
drop_column=drop_column,
),
id,
)


@@ -268,8 +255,9 @@ def cull(

def _construct_graph(self) -> _T_LowLevelGraph:
token = tokenize(self.name_input, self.column, self.npartitions, self.parts_out)
shuffle_id = ShuffleId(token)
dsk: _T_LowLevelGraph = {}
_barrier_key = barrier_key(ShuffleId(token))
_barrier_key = barrier_key(shuffle_id)
name = "shuffle-transfer-" + token
transfer_keys = list()
for i in range(self.npartitions_input):
@@ -279,17 +267,25 @@
TaskRef((self.name_input, i)),
token,
i,
self.npartitions,
self.column,
self.meta_input,
self.parts_out,
self.disk,
self.drop_column,
)
dsk[t.key] = t
transfer_keys.append(t.ref())

barrier = Task(_barrier_key, p2p_barrier, token, transfer_keys)
barrier = P2PBarrierTask(
_barrier_key,
p2p_barrier,
token,
transfer_keys,
spec=DataFrameShuffleSpec(
id=shuffle_id,
npartitions=self.npartitions,
column=self.column,
meta=self.meta_input,
parts_out=self.parts_out,
disk=self.disk,
drop_column=self.drop_column,
),
)
dsk[barrier.key] = barrier

name = self.name
Expand Down
37 changes: 12 additions & 25 deletions distributed/shuffle/_worker_plugin.py
@@ -14,14 +14,7 @@

from distributed.core import ErrorMessage, OKMessage, clean_exception, error_message
from distributed.diagnostics.plugin import WorkerPlugin
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
NDIndex,
ShuffleId,
ShuffleRun,
ShuffleRunSpec,
ShuffleSpec,
)
from distributed.shuffle._core import NDIndex, ShuffleId, ShuffleRun, ShuffleRunSpec
from distributed.shuffle._exceptions import P2PConsistencyError, ShuffleClosedError
from distributed.shuffle._limiter import ResourceLimiter
from distributed.utils import log_errors, sync
@@ -136,24 +129,21 @@ async def get_with_run_id(self, shuffle_id: ShuffleId, run_id: int) -> ShuffleRun:
raise shuffle_run._exception
return shuffle_run

async def get_or_create(self, spec: ShuffleSpec, key: Key) -> ShuffleRun:
async def get_or_create(self, shuffle_id: ShuffleId, key: Key) -> ShuffleRun:
"""Get or create a shuffle matching the ID and data spec.

Parameters
----------
shuffle_id
Unique identifier of the shuffle
key:
Task key triggering the function
"""
async with self._refresh_locks[spec.id]:
shuffle_run = self._active_runs.get(spec.id, None)
async with self._refresh_locks[shuffle_id]:
shuffle_run = self._active_runs.get(shuffle_id, None)
if shuffle_run is None:
shuffle_run = await self._refresh(
shuffle_id=spec.id,
spec=spec,
shuffle_id=shuffle_id,
key=key,
)

@@ -189,17 +179,16 @@ async def get_most_recent(
async def _fetch(
self,
shuffle_id: ShuffleId,
spec: ShuffleSpec | None = None,
key: Key | None = None,
) -> ShuffleRunSpec:
if spec is None:
if key is None:
response = await self._plugin.worker.scheduler.shuffle_get(
id=shuffle_id,
worker=self._plugin.worker.address,
)
else:
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
spec=ToPickle(spec),
shuffle_id=shuffle_id,
key=key,
worker=self._plugin.worker.address,
)
@@ -222,17 +211,15 @@ async def _refresh(
async def _refresh(
self,
shuffle_id: ShuffleId,
spec: ShuffleSpec,
key: Key,
) -> ShuffleRun: ...

async def _refresh(
self,
shuffle_id: ShuffleId,
spec: ShuffleSpec | None = None,
key: Key | None = None,
) -> ShuffleRun:
result = await self._fetch(shuffle_id=shuffle_id, spec=spec, key=key)
result = await self._fetch(shuffle_id=shuffle_id, key=key)
if self.closed:
raise ShuffleClosedError(f"{self} has already been closed")
if existing := self._active_runs.get(shuffle_id, None):
@@ -355,10 +342,10 @@ def add_partition(
self,
data: Any,
partition_id: int | NDIndex,
spec: ShuffleSpec,
id: ShuffleId,
**kwargs: Any,
) -> int:
shuffle_run = self.get_or_create_shuffle(spec)
shuffle_run = self.get_or_create_shuffle(id)
return shuffle_run.add_partition(
data=data,
partition_id=partition_id,
@@ -418,13 +405,13 @@ def get_shuffle_run(

def get_or_create_shuffle(
self,
spec: ShuffleSpec,
shuffle_id: ShuffleId,
) -> ShuffleRun:
key = thread_state.key
return sync(
self.worker.loop,
self.shuffle_runs.get_or_create,
spec,
shuffle_id,
key,
)
