Allow P2P to store data in-memory (dask#8279)
hendrikmakait authored Oct 19, 2023
1 parent ce813eb commit cbc3a33
Showing 12 changed files with 228 additions and 24 deletions.
4 changes: 4 additions & 0 deletions distributed/distributed-schema.yaml
@@ -1066,6 +1066,10 @@ properties:
max:
type: string
description: The maximum delay between retries
disk:
type: boolean
description: |
Whether or not P2P stores intermediate data on disk instead of in memory
dashboard:
type: object
1 change: 1 addition & 0 deletions distributed/distributed.yaml
@@ -304,6 +304,7 @@ distributed:
delay:
min: 1s # the first non-zero delay between re-tries
max: 30s # the maximum delay between re-tries
disk: True

###################
# Bokeh dashboard #
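
A minimal sketch of how the new option can be toggled from user code; the key matches the distributed.p2p.disk setting added above, and dask.config.set is the standard way to override it:

import dask

# Keep P2P intermediate shards in memory instead of spilling them to disk.
# distributed.p2p.disk defaults to True, i.e. the existing disk-backed behaviour.
with dask.config.set({"distributed.p2p.disk": False}):
    ...  # run P2P shuffles / rechunks here
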
32 changes: 26 additions & 6 deletions distributed/shuffle/_core.py
@@ -25,6 +25,7 @@
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import ShuffleClosedError
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils_comm import retry

if TYPE_CHECKING:
@@ -47,6 +48,17 @@


class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
diagnostics: dict[str, float]

def __init__(
self,
id: ShuffleId,
@@ -58,6 +70,7 @@ def __init__(
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
):
self.id = id
self.run_id = run_id
@@ -66,12 +79,14 @@ def __init__(
self.rpc = rpc
self.scheduler = scheduler
self.closed = False

self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

self._comm_buffer = CommShardsBuffer(
send=self.send, memory_limiter=memory_limiter_comms
@@ -270,6 +285,10 @@ async def _get_output_partition(
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""

@abc.abstractmethod
def deserialize(self, buffer: bytes) -> Any:
"""Deserialize shards"""


def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
@@ -321,6 +340,7 @@ def id(self) -> ShuffleId:
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool

def create_new_run(
self,
2 changes: 1 addition & 1 deletion distributed/shuffle/_disk.py
@@ -163,7 +163,7 @@ async def _process(self, id: str, shards: list[bytes]) -> None:
for shard in shards:
f.write(shard)

def read(self, id: int | str) -> Any:
def read(self, id: str) -> Any:
"""Read a complete file back into memory"""
self.raise_on_exception()
if not self._inputs_done:
49 changes: 49 additions & 0 deletions distributed/shuffle/_memory.py
@@ -0,0 +1,49 @@
from __future__ import annotations

from collections import defaultdict, deque
from typing import Any, Callable

from dask.sizeof import sizeof

from distributed.shuffle._buffer import ShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter
from distributed.utils import log_errors


class MemoryShardsBuffer(ShardsBuffer):
_deserialize: Callable[[bytes], Any]
_shards: defaultdict[str, deque[bytes]]

def __init__(self, deserialize: Callable[[bytes], Any]) -> None:
super().__init__(
memory_limiter=ResourceLimiter(None),
)
self._deserialize = deserialize
self._shards = defaultdict(deque)

async def _process(self, id: str, shards: list[bytes]) -> None:
# TODO: This can be greatly simplified, there's no need for
# background threads at all.
with log_errors():
with self.time("write"):
self._shards[id].extend(shards)

def read(self, id: str) -> Any:
self.raise_on_exception()
if not self._inputs_done:
raise RuntimeError("Tried to read from file before done.")

with self.time("read"):
data = []
size = 0
shards = self._shards[id]
while shards:
shard = shards.pop()
data.append(self._deserialize(shard))
size += sizeof(shard)

if data:
self.bytes_read += size
return data
else:
raise KeyError(id)
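
A self-contained sketch of the pattern MemoryShardsBuffer implements: shards arrive already serialized, are buffered per output partition, and are only deserialized when that partition is read back. The toy class below uses only the standard library; it is not the distributed API and its names are illustrative.

import pickle
from collections import defaultdict, deque
from typing import Any, Callable


class ToyMemoryShardStore:
    """Toy stand-in: append serialized shards, deserialize on read."""

    def __init__(self, deserialize: Callable[[bytes], Any] = pickle.loads) -> None:
        self._deserialize = deserialize
        self._shards: defaultdict[str, deque[bytes]] = defaultdict(deque)

    def write(self, id: str, shards: list[bytes]) -> None:
        # Shards are already serialized; buffering them is just an append.
        self._shards[id].extend(shards)

    def read(self, id: str) -> list[Any]:
        shards = self._shards[id]
        if not shards:
            raise KeyError(id)
        # Deserialization is deferred until the output partition is assembled.
        return [self._deserialize(shard) for shard in shards]


store = ToyMemoryShardStore()
store.write("p0", [pickle.dumps({"x": 1}), pickle.dumps({"x": 2})])
assert store.read("p0") == [{"x": 1}, {"x": 2}]
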
10 changes: 10 additions & 0 deletions distributed/shuffle/_merge.py
@@ -4,6 +4,7 @@
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any

import dask
from dask.base import is_dask_collection, tokenize
from dask.highlevelgraph import HighLevelGraph
from dask.layers import Layer
@@ -106,6 +107,7 @@ def hash_join_p2p(
lhs = _calculate_partitions(lhs, left_on, npartitions)
rhs = _calculate_partitions(rhs, right_on, npartitions)
merge_name = "hash-join-" + tokenize(lhs, rhs, **merge_kwargs)
disk: bool = dask.config.get("distributed.p2p.disk")
join_layer = HashJoinP2PLayer(
name=merge_name,
name_input_left=lhs._name,
@@ -123,6 +125,7 @@
indicator=indicator,
left_index=left_index,
right_index=right_index,
disk=disk,
)
graph = HighLevelGraph.from_collections(
merge_name, join_layer, dependencies=[lhs, rhs]
@@ -142,6 +145,7 @@ def merge_transfer(
npartitions: int,
meta: pd.DataFrame,
parts_out: set[int],
disk: bool,
):
return shuffle_transfer(
input=input,
@@ -151,6 +155,7 @@
column=_HASH_COLUMN_NAME,
meta=meta,
parts_out=parts_out,
disk=disk,
)


@@ -227,6 +232,7 @@ def __init__(
left_index: bool,
right_index: bool,
npartitions: int,
disk: bool,
how: MergeHow = "inner",
suffixes: Suffixes = ("_x", "_y"),
indicator: bool = False,
@@ -251,6 +257,7 @@ def __init__(
self.n_partitions_right = n_partitions_right
self.left_index = left_index
self.right_index = right_index
self.disk = disk
annotations = annotations or {}
annotations.update({"shuffle": lambda key: key[-1]})
super().__init__(annotations=annotations)
@@ -332,6 +339,7 @@ def _cull(self, parts_out: Sequence[str]):
parts_out=parts_out,
left_index=self.left_index,
right_index=self.right_index,
disk=self.disk,
annotations=self.annotations,
n_partitions_left=self.n_partitions_left,
n_partitions_right=self.n_partitions_right,
@@ -381,6 +389,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
self.npartitions,
self.meta_input_left,
self.parts_out,
self.disk,
)
for i in range(self.n_partitions_right):
transfer_keys_right.append((name_right, i))
@@ -392,6 +401,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
self.npartitions,
self.meta_input_right,
self.parts_out,
self.disk,
)

_barrier_key_left = barrier_key(ShuffleId(token_left))
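
For context, a sketch of a P2P hash join that would exercise this code path. The shuffle="p2p" keyword and the toy frames are assumptions for illustration; distributed.p2p.disk only controls where the join's shards are buffered.

import dask
import dask.dataframe as dd
import pandas as pd
from distributed import Client

client = Client()  # P2P shuffling requires a distributed scheduler

left = dd.from_pandas(
    pd.DataFrame({"id": list(range(100)), "x": [1] * 100}), npartitions=10
)
right = dd.from_pandas(
    pd.DataFrame({"id": list(range(100)), "y": [2] * 100}), npartitions=10
)

# With the new setting disabled, the hash join buffers its shards in memory.
with dask.config.set({"distributed.p2p.disk": False}):
    joined = left.merge(right, on="id", shuffle="p2p").compute()
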
23 changes: 17 additions & 6 deletions distributed/shuffle/_rechunk.py
@@ -143,12 +143,13 @@ def rechunk_transfer(
input_chunk: NDIndex,
new: ChunkedAxes,
old: ChunkedAxes,
disk: bool,
) -> int:
with handle_transfer_errors(id):
return get_worker_plugin().add_partition(
input,
partition_id=input_chunk,
spec=ArrayRechunkSpec(id=id, new=new, old=old),
spec=ArrayRechunkSpec(id=id, new=new, old=old, disk=disk),
)


@@ -174,6 +175,7 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array:
token = tokenize(x, chunks)
_barrier_key = barrier_key(ShuffleId(token))
name = f"rechunk-transfer-{token}"
disk: bool = dask.config.get("distributed.p2p.disk")
transfer_keys = []
for index in np.ndindex(tuple(len(dim) for dim in x.chunks)):
transfer_keys.append((name,) + index)
@@ -184,6 +186,7 @@
index,
chunks,
x.chunks,
disk,
)

dsk[_barrier_key] = (shuffle_barrier, token, transfer_keys)
@@ -258,14 +261,15 @@ def split_axes(old: ChunkedAxes, new: ChunkedAxes) -> SplitAxes:
return axes


def convert_chunk(shards: list[tuple[NDIndex, np.ndarray]]) -> np.ndarray:
def convert_chunk(shards: list[list[tuple[NDIndex, np.ndarray]]]) -> np.ndarray:
import numpy as np

from dask.array.core import concatenate3

indexed: dict[NDIndex, np.ndarray] = {}
for index, shard in shards:
indexed[index] = shard
for sublist in shards:
for index, shard in sublist:
indexed[index] = shard
del shards

subshape = [max(dim) + 1 for dim in zip(*indexed.keys())]
@@ -333,6 +337,7 @@ def __init__(
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
):
super().__init__(
id=id,
@@ -344,6 +349,7 @@
scheduler=scheduler,
memory_limiter_comms=memory_limiter_comms,
memory_limiter_disk=memory_limiter_disk,
disk=disk,
)
self.old = old
self.new = new
@@ -429,13 +435,17 @@ def _(partition_id: NDIndex) -> np.ndarray:

return await self.offload(_, partition_id)

def deserialize(self, buffer: bytes) -> Any:
result = pickle.loads(buffer)
return result

def read(self, path: Path) -> tuple[Any, int]:
shards: list[tuple[NDIndex, np.ndarray]] = []
shards: list[list[tuple[NDIndex, np.ndarray]]] = []
with path.open(mode="rb") as f:
size = f.seek(0, os.SEEK_END)
f.seek(0)
while f.tell() < size:
shards.extend(pickle.load(f))
shards.append(pickle.load(f))
return shards, size

def _get_assigned_worker(self, id: NDIndex) -> str:
@@ -475,6 +485,7 @@ def create_run_on_worker(
scheduler=plugin.worker.scheduler,
memory_limiter_disk=plugin.memory_limiter_disk,
memory_limiter_comms=plugin.memory_limiter_comms,
disk=self.disk,
)


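
An analogous sketch for rechunking, assuming array.rechunk.method selects the P2P implementation; distributed.p2p.disk then decides whether the rechunk run spills its shards to disk or keeps them in memory. Shapes and chunk sizes are illustrative.

import dask
import dask.array as da
from distributed import Client

client = Client()  # P2P rechunking also requires a distributed scheduler

x = da.random.random((2000, 2000), chunks=(100, 2000))

# Re-partition from row-oriented to column-oriented chunks entirely in memory.
with dask.config.set(
    {"array.rechunk.method": "p2p", "distributed.p2p.disk": False}
):
    y = x.rechunk((2000, 100)).persist()
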