Reduce memory footprint of P2P shuffling (#8157)
hendrikmakait authored Sep 14, 2023
1 parent 9129dae commit e57d1c5
Showing 9 changed files with 129 additions and 65 deletions.
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.9.yaml
@@ -28,7 +28,7 @@ dependencies:
- pre-commit
- prometheus_client
- psutil
- pyarrow=7
- pyarrow=12
- pynvml # Only tested here
- pytest
- pytest-cov
74 changes: 53 additions & 21 deletions distributed/shuffle/_arrow.py
@@ -1,10 +1,12 @@
from __future__ import annotations

from io import BytesIO
from typing import TYPE_CHECKING
from pathlib import Path
from typing import TYPE_CHECKING, Any

from packaging.version import parse

from dask.utils import parse_bytes

if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa
Expand All @@ -29,34 +31,27 @@ def check_minimal_arrow_version() -> None:
"""Verify that the the correct version of pyarrow is installed to support
the P2P extension.
Raises a RuntimeError in case pyarrow is not installed or installed version
is not recent enough.
Raises a ModuleNotFoundError if pyarrow is not installed or an
ImportError if the installed version is not recent enough.
"""
# First version to introduce Table.sort_by
minversion = "7.0.0"
# First version that supports concatenating extension arrays (apache/arrow#14463)
minversion = "12.0.0"
try:
import pyarrow as pa
except ImportError:
raise RuntimeError(f"P2P shuffling requires pyarrow>={minversion}")

except ModuleNotFoundError:
raise ModuleNotFoundError(f"P2P shuffling requires pyarrow>={minversion}")
if parse(pa.__version__) < parse(minversion):
raise RuntimeError(
raise ImportError(
f"P2P shuffling requires pyarrow>={minversion} but only found {pa.__version__}"
)


def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame:
def convert_shards(shards: list[pa.Table], meta: pd.DataFrame) -> pd.DataFrame:
import pyarrow as pa

from dask.dataframe.dispatch import from_pyarrow_table_dispatch

file = BytesIO(data)
end = len(data)
shards = []
while file.tell() < end:
sr = pa.RecordBatchStreamReader(file)
shards.append(sr.read_all())
table = pa.concat_tables(shards, promote=True)
table = pa.concat_tables(shards)

df = from_pyarrow_table_dispatch(meta, table, self_destruct=True)
return df.astype(meta.dtypes, copy=False)
@@ -66,9 +61,7 @@ def list_of_buffers_to_table(data: list[bytes]) -> pa.Table:
"""Convert a list of arrow buffers and a schema to an Arrow Table"""
import pyarrow as pa

return pa.concat_tables(
(deserialize_table(buffer) for buffer in data), promote=True
)
return pa.concat_tables(deserialize_table(buffer) for buffer in data)


def serialize_table(table: pa.Table) -> bytes:
@@ -85,3 +78,42 @@ def deserialize_table(buffer: bytes) -> pa.Table:

with pa.ipc.open_stream(pa.py_buffer(buffer)) as reader:
return reader.read_all()


def read_from_disk(path: Path, meta: pd.DataFrame) -> tuple[Any, int]:
import pyarrow as pa

from dask.dataframe.dispatch import pyarrow_schema_dispatch

batch_size = parse_bytes("1 MiB")
batch = []
shards = []
schema = pyarrow_schema_dispatch(meta, preserve_index=True)

with pa.OSFile(str(path), mode="rb") as f:
size = f.seek(0, whence=2)
f.seek(0)
prev = 0
offset = f.tell()
while offset < size:
sr = pa.RecordBatchStreamReader(f)
shard = sr.read_all()
offset = f.tell()
batch.append(shard)

if offset - prev >= batch_size:
table = pa.concat_tables(batch)
shards.append(_copy_table(table, schema))
batch = []
prev = offset
if batch:
table = pa.concat_tables(batch)
shards.append(_copy_table(table, schema))
return shards, size


def _copy_table(table: pa.Table, schema: pa.Schema) -> pa.Table:
import pyarrow as pa

arrs = [pa.concat_arrays(column.chunks) for column in table.columns]
return pa.table(data=arrs, schema=schema)
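
For reference, a minimal, self-contained sketch of the batched read pattern that read_from_disk introduces. The helper names (make_example_file, read_batched) are hypothetical, and pa.Table.combine_chunks stands in for _copy_table; the idea is that consecutive IPC streams are consolidated into defragmented tables roughly every batch_size bytes instead of loading the whole spill file into a single bytes object.

import tempfile
from pathlib import Path

import pyarrow as pa


def make_example_file(path: Path, n_streams: int = 5) -> None:
    # Emulate the spill format: each buffer flush appends one independent
    # Arrow IPC stream to the same file.
    with path.open("ab") as f:
        for i in range(n_streams):
            table = pa.table({"x": list(range(i * 10, i * 10 + 10))})
            sink = pa.BufferOutputStream()
            with pa.ipc.new_stream(sink, table.schema) as writer:
                writer.write_table(table)
            f.write(sink.getvalue().to_pybytes())


def read_batched(path: Path, batch_size: int = 1024) -> list[pa.Table]:
    shards: list[pa.Table] = []
    batch: list[pa.Table] = []
    with pa.OSFile(str(path), mode="rb") as f:
        size = f.seek(0, 2)  # learn the file size, then rewind
        f.seek(0)
        prev = 0
        while f.tell() < size:
            # Each RecordBatchStreamReader consumes exactly one IPC stream.
            batch.append(pa.RecordBatchStreamReader(f).read_all())
            if f.tell() - prev >= batch_size:
                # Concatenate and defragment so the many small
                # deserialization buffers can be released early.
                shards.append(pa.concat_tables(batch).combine_chunks())
                batch, prev = [], f.tell()
        if batch:
            shards.append(pa.concat_tables(batch).combine_chunks())
    return shards


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "spilled-shards"
        make_example_file(path)
        # Row counts of the consolidated shards
        print([t.num_rows for t in read_batched(path)])
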
11 changes: 8 additions & 3 deletions distributed/shuffle/_core.py
@@ -11,6 +11,7 @@
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar

from distributed.core import PooledRPCCall
@@ -62,6 +63,7 @@ def __init__(

self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)

@@ -180,10 +182,9 @@ def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception

def _read_from_disk(self, id: NDIndex) -> bytes:
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
data: bytes = self._disk_buffer.read("_".join(str(i) for i in id))
return data
return self._disk_buffer.read("_".join(str(i) for i in id))

async def receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None:
await self._receive(data)
@@ -238,6 +239,10 @@ async def _get_output_partition(
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""

@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""


def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
11 changes: 5 additions & 6 deletions distributed/shuffle/_disk.py
@@ -3,6 +3,7 @@
import contextlib
import pathlib
import shutil
from typing import Any, Callable

from distributed.shuffle._buffer import ShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter
@@ -41,6 +42,7 @@ class DiskShardsBuffer(ShardsBuffer):
def __init__(
self,
directory: str | pathlib.Path,
read: Callable[[pathlib.Path], tuple[Any, int]],
memory_limiter: ResourceLimiter | None = None,
):
super().__init__(
@@ -50,6 +52,7 @@ def __init__(
)
self.directory = pathlib.Path(directory)
self.directory.mkdir(exist_ok=True)
self._read = read

async def _process(self, id: str, shards: list[bytes]) -> None:
"""Write one buffer to file
@@ -74,19 +77,15 @@ async def _process(self, id: str, shards: list[bytes]) -> None:
for shard in shards:
f.write(shard)

def read(self, id: int | str) -> bytes:
def read(self, id: int | str) -> Any:
"""Read a complete file back into memory"""
self.raise_on_exception()
if not self._inputs_done:
raise RuntimeError("Tried to read from file before done.")

try:
with self.time("read"):
with open(
self.directory / str(id), mode="rb", buffering=100_000_000
) as f:
data = f.read()
size = f.tell()
data, size = self._read((self.directory / str(id)).resolve())
except FileNotFoundError:
raise KeyError(id)

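
To make the new wiring easier to follow, here is a condensed, hypothetical sketch of the dependency injection above: DiskShardsBuffer.read no longer returns raw bytes but delegates to the read callable supplied by the owning shuffle run, which decides how its spilled shards are deserialized. MiniDiskBuffer and toy_read are illustrative names only, not part of the distributed codebase.

from __future__ import annotations

import pickle
import tempfile
from pathlib import Path
from typing import Any, Callable


class MiniDiskBuffer:
    def __init__(
        self, directory: Path, read: Callable[[Path], tuple[Any, int]]
    ) -> None:
        self.directory = directory
        self._read = read  # injected by the concrete ShuffleRun

    def read(self, id: str) -> Any:
        try:
            data, size = self._read(self.directory / str(id))
        except FileNotFoundError:
            raise KeyError(id)
        # In the real buffer, `size` feeds the read metrics.
        return data


def toy_read(path: Path) -> tuple[Any, int]:
    # Stand-in for ShuffleRun.read: deserialize one spilled file and report
    # how many bytes it contained.
    raw = path.read_bytes()
    return pickle.loads(raw), len(raw)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        directory = Path(tmp)
        (directory / "0").write_bytes(pickle.dumps(["shard-a", "shard-b"]))
        buffer = MiniDiskBuffer(directory, read=toy_read)
        print(buffer.read("0"))  # ['shard-a', 'shard-b']
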
38 changes: 23 additions & 15 deletions distributed/shuffle/_rechunk.py
@@ -102,8 +102,8 @@
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from io import BytesIO
from itertools import product
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple

import dask
@@ -258,26 +258,22 @@ def split_axes(old: ChunkedAxes, new: ChunkedAxes) -> SplitAxes:
return axes


def convert_chunk(data: bytes) -> np.ndarray:
def convert_chunk(shards: list[tuple[NDIndex, np.ndarray]]) -> np.ndarray:
import numpy as np

from dask.array.core import concatenate3

file = BytesIO(data)
shards: dict[NDIndex, np.ndarray] = {}
indexed: dict[NDIndex, np.ndarray] = {}
for index, shard in shards:
indexed[index] = shard
del shards

while file.tell() < len(data):
for index, shard in pickle.load(file):
shards[index] = shard

subshape = [max(dim) + 1 for dim in zip(*shards.keys())]
assert len(shards) == np.prod(subshape)
subshape = [max(dim) + 1 for dim in zip(*indexed.keys())]
assert len(indexed) == np.prod(subshape)

rec_cat_arg = np.empty(subshape, dtype="O")
for index, shard in shards.items():
for index, shard in indexed.items():
rec_cat_arg[tuple(index)] = shard
del data
del file
arrs = rec_cat_arg.tolist()
return concatenate3(arrs)

@@ -427,8 +423,20 @@ def _() -> dict[str, tuple[NDIndex, bytes]]:
async def _get_output_partition(
self, partition_id: NDIndex, key: str, **kwargs: Any
) -> np.ndarray:
data = self._read_from_disk(partition_id)
return await self.offload(convert_chunk, data)
def _(partition_id: NDIndex) -> np.ndarray:
data = self._read_from_disk(partition_id)
return convert_chunk(data)

return await self.offload(_, partition_id)

def read(self, path: Path) -> tuple[Any, int]:
shards: list[tuple[NDIndex, np.ndarray]] = []
with path.open(mode="rb") as f:
size = f.seek(0, os.SEEK_END)
f.seek(0)
while f.tell() < size:
shards.extend(pickle.load(f))
return shards, size

def _get_assigned_worker(self, id: NDIndex) -> str:
return self.worker_for[id]
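
For clarity, an illustrative, runnable sketch of the spill format the rechunk read method above consumes: every buffer flush appends one pickled list of (index, shard) tuples to the partition file, so reading back is a loop of pickle.load calls until the end of the file. The file name is hypothetical.

import os
import pickle
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "spilled-partition")

    # Two separate writes, as two flushes of the shard buffer would produce.
    with open(path, "ab") as f:
        pickle.dump([((0, 0), np.ones((2, 2))), ((0, 1), np.zeros((2, 2)))], f)
        pickle.dump([((1, 0), np.full((2, 2), 7))], f)

    shards = []
    with open(path, "rb") as f:
        size = f.seek(0, os.SEEK_END)
        f.seek(0)
        while f.tell() < size:
            shards.extend(pickle.load(f))

    print([index for index, _ in shards])  # [(0, 0), (0, 1), (1, 0)]

A list assembled this way is what convert_chunk above stitches back into a single array.
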
21 changes: 17 additions & 4 deletions distributed/shuffle/_shuffle.py
@@ -7,6 +7,7 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any

import toolz
@@ -20,8 +21,9 @@
from distributed.shuffle._arrow import (
check_dtype_support,
check_minimal_arrow_version,
convert_partition,
convert_shards,
list_of_buffers_to_table,
read_from_disk,
serialize_table,
)
from distributed.shuffle._core import (
Expand Down Expand Up @@ -321,7 +323,7 @@ def split_by_worker(
return out


def split_by_partition(t: pa.Table, column: str) -> dict[Any, pa.Table]:
def split_by_partition(t: pa.Table, column: str) -> dict[int, pa.Table]:
"""
Split data into many arrow batches, partitioned by final partition
"""
@@ -383,6 +385,11 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]):
buffer.
"""

column: str
meta: pd.DataFrame
partitions_of: dict[str, list[int]]
worker_for: pd.Series

def __init__(
self,
worker_for: dict[int, str],
@@ -476,16 +483,22 @@ async def _get_output_partition(
**kwargs: Any,
) -> pd.DataFrame:
try:
data = self._read_from_disk((partition_id,))

out = await self.offload(convert_partition, data, self.meta)
def _(partition_id: int, meta: pd.DataFrame) -> pd.DataFrame:
data = self._read_from_disk((partition_id,))
return convert_shards(data, meta)

out = await self.offload(_, partition_id, self.meta)
except KeyError:
out = self.meta.copy()
return out

def _get_assigned_worker(self, id: int) -> str:
return self.worker_for[id]

def read(self, path: Path) -> tuple[Any, int]:
return read_from_disk(path, self.meta)


@dataclass(frozen=True)
class DataFrameShuffleSpec(ShuffleSpec[int]):
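
Finally, a small sketch of the offload pattern that the dataframe and rechunk runs now share: the disk read and the conversion happen inside one offloaded function instead of reading on the event loop and offloading only the conversion. run_in_executor stands in for ShuffleRun.offload, and the reader/converter callables are placeholders.

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable

executor = ThreadPoolExecutor(1)


async def get_output_partition(
    read_from_disk: Callable[[int], list[Any]],
    convert: Callable[[list[Any]], Any],
    partition_id: int,
) -> Any:
    def _() -> Any:
        shards = read_from_disk(partition_id)  # disk I/O off the event loop
        return convert(shards)  # conversion in the same offloaded task

    return await asyncio.get_running_loop().run_in_executor(executor, _)


async def main() -> None:
    print(await get_output_partition(lambda i: [i, i + 1], sum, 3))  # 7


asyncio.run(main())
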