Invert offloading between event loop and threads for P2P (#8322)
hendrikmakait authored Nov 2, 2023
1 parent 76dd800 commit 954e9d0
Showing 6 changed files with 92 additions and 94 deletions.
39 changes: 26 additions & 13 deletions distributed/shuffle/_core.py
@@ -14,6 +14,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar

+from tornado.ioloop import IOLoop
+
import dask.config
from dask.typing import Key
from dask.utils import parse_timedelta
@@ -26,6 +28,7 @@
from distributed.shuffle._exceptions import ShuffleClosedError
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
+from distributed.utils import sync
from distributed.utils_comm import retry

if TYPE_CHECKING:
@@ -58,6 +61,12 @@ class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
diagnostics: dict[str, float]
+received: set[_T_partition_id]
+total_recvd: int
+start_time: float
+_exception: Exception | None
+_closed_event: asyncio.Event
+_loop: IOLoop

def __init__(
self,
@@ -71,6 +80,7 @@ def __init__(
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
+loop: IOLoop,
):
self.id = id
self.run_id = run_id
@@ -94,13 +104,14 @@ def __init__(
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)

-self.diagnostics: dict[str, float] = defaultdict(float)
+self.diagnostics = defaultdict(float)
self.transferred = False
-self.received: set[_T_partition_id] = set()
+self.received = set()
self.total_recvd = 0
self.start_time = time.time()
-self._exception: Exception | None = None
+self._exception = None
self._closed_event = asyncio.Event()
+self._loop = loop

def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
@@ -251,32 +262,34 @@ def _get_assigned_worker(self, i: _T_partition_id) -> str:
async def _receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""

-async def add_partition(
+def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
-return await self._add_partition(data, partition_id)
+shards = self._shard_partition(data, partition_id)
+sync(self._loop, self._write_to_comm, shards)
+return self.run_id

@abc.abstractmethod
-async def _add_partition(
+def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
-) -> int:
-"""Add an input partition to the shuffle run"""
+) -> dict[str, tuple[_T_partition_id, bytes]]:
+"""Shard an input partition by the assigned output workers"""

-async def get_output_partition(
+def get_output_partition(
self, partition_id: _T_partition_id, key: str, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
-await self._ensure_output_worker(partition_id, key)
+sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
-await self.flush_receive()
-return await self._get_output_partition(partition_id, key, **kwargs)
+sync(self._loop, self.flush_receive)
+return self._get_output_partition(partition_id, key, **kwargs)

@abc.abstractmethod
-async def _get_output_partition(
+def _get_output_partition(
self, partition_id: _T_partition_id, key: str, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
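The core of the change is visible in `add_partition`/`get_output_partition` above: the blocking call now runs in a worker thread and hands only its awaitable parts to the event loop via `sync(self._loop, ...)`, instead of an async method offloading CPU-bound work to a thread. A minimal standard-library sketch of that control flow (illustrative names only, not the distributed API; `run_coroutine_threadsafe(...).result()` stands in for `sync`):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


async def write_to_comm(shards: dict[str, bytes]) -> None:
    # Stand-in for the async comm/disk buffer writes that must stay on the event loop.
    await asyncio.sleep(0)


def add_partition_blocking(loop: asyncio.AbstractEventLoop, data: bytes) -> int:
    # Runs in a worker thread: the CPU-heavy sharding happens right here ...
    shards = {"worker-1": data}
    # ... and only the awaitable part hops back onto the event loop; this thread
    # blocks on the result, which is the role sync(self._loop, ...) plays above.
    asyncio.run_coroutine_threadsafe(write_to_comm(shards), loop).result()
    return 1  # run_id in the real code


async def main() -> None:
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=1) as pool:
        run_id = await loop.run_in_executor(pool, add_partition_blocking, loop, b"shard")
    assert run_id == 1


asyncio.run(main())
```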
58 changes: 21 additions & 37 deletions distributed/shuffle/_rechunk.py
@@ -106,6 +106,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple

+from tornado.ioloop import IOLoop
+
import dask
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
@@ -338,6 +340,7 @@ def __init__(
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
+loop: IOLoop,
):
super().__init__(
id=id,
@@ -350,6 +353,7 @@ def __init__(
memory_limiter_comms=memory_limiter_comms,
memory_limiter_disk=memory_limiter_disk,
disk=disk,
+loop=loop,
)
self.old = old
self.new = new
@@ -391,49 +395,28 @@ def _repartition_shards(self, data: list[bytes]) -> dict[NDIndex, bytes]:
repartitioned[id].append(shard)
return {k: pickle.dumps(v) for k, v in repartitioned.items()}

-    async def _add_partition(
+    def _shard_partition(
        self, data: np.ndarray, partition_id: NDIndex, **kwargs: Any
-    ) -> int:
-        def _() -> dict[str, tuple[NDIndex, bytes]]:
-            """Return a mapping of worker addresses to a tuple of input partition
-            IDs and shard data.
-
-            TODO: Overhaul!
-            As shard data, we serialize the payload together with the sub-index of the
-            slice within the new chunk. To assemble the new chunk from its shards, it
-            needs the sub-index to know where each shard belongs within the chunk.
-            Adding the sub-index into the serialized payload on the sender allows us to
-            write the serialized payload directly to disk on the receiver.
-            """
-            out: dict[
-                str, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]]
-            ] = defaultdict(list)
-            from itertools import product
-
-            ndsplits = product(
-                *(axis[i] for axis, i in zip(self.split_axes, partition_id))
-            )
-
-            for ndsplit in ndsplits:
-                chunk_index, shard_index, ndslice = zip(*ndsplit)
-                out[self.worker_for[chunk_index]].append(
-                    (chunk_index, (shard_index, data[ndslice]))
-                )
-            return {k: (partition_id, pickle.dumps(v)) for k, v in out.items()}
-
-        out = await self.offload(_)
-        await self._write_to_comm(out)
-        return self.run_id
+    ) -> dict[str, tuple[NDIndex, bytes]]:
+        out: dict[str, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]]] = defaultdict(
+            list
+        )
+        from itertools import product
+
+        ndsplits = product(*(axis[i] for axis, i in zip(self.split_axes, partition_id)))
+
+        for ndsplit in ndsplits:
+            chunk_index, shard_index, ndslice = zip(*ndsplit)
+            out[self.worker_for[chunk_index]].append(
+                (chunk_index, (shard_index, data[ndslice]))
+            )
+        return {k: (partition_id, pickle.dumps(v)) for k, v in out.items()}

-    async def _get_output_partition(
+    def _get_output_partition(
        self, partition_id: NDIndex, key: str, **kwargs: Any
    ) -> np.ndarray:
-        def _(partition_id: NDIndex) -> np.ndarray:
-            data = self._read_from_disk(partition_id)
-            return convert_chunk(data)
-
-        return await self.offload(_, partition_id)
+        data = self._read_from_disk(partition_id)
+        return convert_chunk(data)

def deserialize(self, buffer: bytes) -> Any:
result = pickle.loads(buffer)
@@ -486,6 +469,7 @@ def create_run_on_worker(
memory_limiter_disk=plugin.memory_limiter_disk,
memory_limiter_comms=plugin.memory_limiter_comms,
disk=self.disk,
+loop=plugin.worker.loop,
)


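The removed docstring above describes a serialization trick that `_shard_partition` keeps: each shard is pickled together with its sub-index inside the destination chunk, so the receiver can write the payload straight to disk and still knows where the shard belongs when the chunk is assembled. A rough, self-contained round trip of that idea (plain `pickle` plus NumPy, not the actual helpers):

```python
import pickle

import numpy as np

chunk = np.arange(12).reshape(3, 4)

# Sender: each shard is pickled together with its sub-index within the
# destination chunk, so the receiving worker can write the payload to disk verbatim.
shards = [((0, 0), chunk[:, :2]), ((0, 1), chunk[:, 2:])]
payload = pickle.dumps(shards)

# Receiver: the bytes are only unpickled when the output chunk is assembled;
# the stored sub-index says where each shard belongs.
assembled = np.empty((1, 2), dtype=object)
for sub_index, shard in pickle.loads(payload):
    assembled[sub_index] = shard
np.testing.assert_array_equal(np.block(assembled.tolist()), chunk)
```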
42 changes: 18 additions & 24 deletions distributed/shuffle/_shuffle.py
@@ -11,6 +11,7 @@
from typing import TYPE_CHECKING, Any

import toolz
+from tornado.ioloop import IOLoop

import dask
from dask.base import tokenize
@@ -431,6 +432,7 @@ def __init__(
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
+loop: IOLoop,
):
import pandas as pd

@@ -445,6 +447,7 @@ def __init__(
memory_limiter_comms=memory_limiter_comms,
memory_limiter_disk=memory_limiter_disk,
disk=disk,
+loop=loop,
)
self.column = column
self.meta = meta
@@ -484,42 +487,32 @@ def _repartition_buffers(self, data: list[bytes]) -> dict[NDIndex, bytes]:
del data
return {(k,): serialize_table(v) for k, v in groups.items()}

-    async def _add_partition(
+    def _shard_partition(
        self,
        data: pd.DataFrame,
        partition_id: int,
        **kwargs: Any,
-    ) -> int:
-        def _() -> dict[str, tuple[int, bytes]]:
-            out = split_by_worker(
-                data,
-                self.column,
-                self.meta,
-                self.worker_for,
-            )
-            out = {k: (partition_id, serialize_table(t)) for k, t in out.items()}
-            return out
-
-        out = await self.offload(_)
-        await self._write_to_comm(out)
-        return self.run_id
+    ) -> dict[str, tuple[int, bytes]]:
+        out = split_by_worker(
+            data,
+            self.column,
+            self.meta,
+            self.worker_for,
+        )
+        out = {k: (partition_id, serialize_table(t)) for k, t in out.items()}
+        return out

-    async def _get_output_partition(
+    def _get_output_partition(
        self,
        partition_id: int,
        key: str,
        **kwargs: Any,
    ) -> pd.DataFrame:
        try:
-
-            def _(partition_id: int, meta: pd.DataFrame) -> pd.DataFrame:
-                data = self._read_from_disk((partition_id,))
-                return convert_shards(data, meta)
-
-            out = await self.offload(_, partition_id, self.meta)
+            data = self._read_from_disk((partition_id,))
+            return convert_shards(data, self.meta)
        except KeyError:
-            out = self.meta.copy()
-        return out
+            return self.meta.copy()

def _get_assigned_worker(self, id: int) -> str:
return self.worker_for[id]
@@ -564,6 +557,7 @@ def create_run_on_worker(
else ResourceLimiter(None),
memory_limiter_comms=plugin.memory_limiter_comms,
disk=self.disk,
+loop=plugin.worker.loop,
)


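The dataframe path mirrors the array path: `_shard_partition` splits the input partition by the worker assigned to each output partition (via `split_by_worker`) before the per-worker buffers go through the comm machinery. A rough pandas-only sketch of that splitting step, with a hypothetical column name and worker addresses (not the real `split_by_worker`):

```python
import pandas as pd

# Hypothetical inputs: the partitioning column and a mapping from output
# partition number to the address of the worker that owns it.
worker_for = {0: "tcp://a:1234", 1: "tcp://b:1234", 2: "tcp://a:1234"}
df = pd.DataFrame({"_partitions": [0, 1, 2, 1, 0], "x": range(5)})

# Group the rows of one input partition by the worker assigned to their output
# partition; each group would then be serialized and handed to the comm buffer
# keyed by that worker's address.
out = {
    addr: shard.drop(columns="_partitions")
    for addr, shard in df.groupby(df["_partitions"].map(worker_for))
}
assert sorted(out) == ["tcp://a:1234", "tcp://b:1234"]
assert len(out["tcp://a:1234"]) == 3  # partitions 0 and 2 both live on worker "a"
```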
8 changes: 2 additions & 6 deletions distributed/shuffle/_worker_plugin.py
@@ -336,9 +336,7 @@ def add_partition(
**kwargs: Any,
) -> int:
shuffle_run = self.get_or_create_shuffle(spec)
-return sync(
-self.worker.loop,
-shuffle_run.add_partition,
+return shuffle_run.add_partition(
data=data,
partition_id=partition_id,
**kwargs,
@@ -428,9 +426,7 @@ def get_output_partition(
"""
shuffle_run = self.get_shuffle_run(shuffle_id, run_id)
key = thread_state.key
-return sync(
-self.worker.loop,
-shuffle_run.get_output_partition,
+return shuffle_run.get_output_partition(
partition_id=partition_id,
key=key,
meta=meta,
16 changes: 10 additions & 6 deletions distributed/shuffle/tests/test_rechunk.py
@@ -13,6 +13,8 @@

from concurrent.futures import ThreadPoolExecutor

+from tornado.ioloop import IOLoop
+
import dask
from dask.array.core import concatenate3
from dask.array.rechunk import normalize_chunks, rechunk
@@ -71,6 +73,7 @@ def new_shuffle(
memory_limiter_disk=ResourceLimiter(10000000),
memory_limiter_comms=ResourceLimiter(10000000),
disk=disk,
+loop=loop,
)
self.shuffles[name] = s
return s
@@ -83,9 +86,8 @@ def new_shuffle(
@pytest.mark.parametrize("barrier_first_worker", [True, False])
@pytest.mark.parametrize("disk", [True, False])
@gen_test()
-async def test_lowlevel_rechunk(
-tmp_path, loop_in_thread, n_workers, barrier_first_worker, disk
-):
+async def test_lowlevel_rechunk(tmp_path, n_workers, barrier_first_worker, disk):
+loop = IOLoop.current()
old = ((1, 2, 3, 4), (5,) * 6)
new = ((5, 5), (12, 18))

@@ -115,7 +117,7 @@ async def test_lowlevel_rechunk(
old=old,
new=new,
directory=tmp_path,
-loop=loop_in_thread,
+loop=loop,
disk=disk,
)
)
@@ -129,7 +131,7 @@
try:
for i, (idx, arr) in enumerate(old_chunks.items()):
s = shuffles[i % len(shuffles)]
-run_ids.append(await s.add_partition(arr, idx))
+run_ids.append(await asyncio.to_thread(s.add_partition, arr, idx))

await barrier_worker.barrier(run_ids)

@@ -148,7 +150,9 @@
all_chunks = np.empty(tuple(len(dim) for dim in new), dtype="O")
for ix, worker in worker_for_mapping.items():
s = local_shuffle_pool.shuffles[worker]
-all_chunks[ix] = await s.get_output_partition(ix, f"key-{ix}")
+all_chunks[ix] = await asyncio.to_thread(
+s.get_output_partition, ix, f"key-{ix}"
+)

finally:
await asyncio.gather(*[s.close() for s in shuffles])
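On the test side the relationship is inverted: the coroutine-based test can no longer `await` `add_partition` or `get_output_partition` directly, so it pushes the now-blocking calls onto threads with `asyncio.to_thread`, much like a worker's task thread would call them. A minimal sketch with a placeholder blocking function:

```python
import asyncio
import time


def add_partition(data: bytes, idx: int) -> int:
    # Placeholder for the now-blocking ShuffleRun.add_partition: it shards in this
    # thread and sync()s the comm write back onto the event loop.
    time.sleep(0.01)
    return idx


async def main() -> None:
    run_ids = await asyncio.gather(
        *(asyncio.to_thread(add_partition, b"shard", i) for i in range(4))
    )
    assert run_ids == [0, 1, 2, 3]


asyncio.run(main())
```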