From cbc3a33cc32e69e34cfc53f74cf4a5416566a5a4 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 19 Oct 2023 12:14:05 +0200 Subject: [PATCH] Allow P2P to store data in-memory (#8279) --- distributed/distributed-schema.yaml | 4 ++ distributed/distributed.yaml | 1 + distributed/shuffle/_core.py | 32 ++++++++-- distributed/shuffle/_disk.py | 2 +- distributed/shuffle/_memory.py | 49 +++++++++++++++ distributed/shuffle/_merge.py | 10 +++ distributed/shuffle/_rechunk.py | 23 +++++-- distributed/shuffle/_shuffle.py | 25 +++++++- .../shuffle/tests/test_memory_buffer.py | 62 +++++++++++++++++++ distributed/shuffle/tests/test_merge.py | 9 ++- distributed/shuffle/tests/test_rechunk.py | 18 ++++-- distributed/shuffle/tests/test_shuffle.py | 17 ++++- 12 files changed, 228 insertions(+), 24 deletions(-) create mode 100644 distributed/shuffle/_memory.py create mode 100644 distributed/shuffle/tests/test_memory_buffer.py diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 6aaed3f3770..47e1610cde3 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -1066,6 +1066,10 @@ properties: max: type: string description: The maximum delay between retries + disk: + type: boolean + description: | + Whether or not P2P stores intermediate data on disk instead of memory dashboard: type: object diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 266f487fd6a..fe379046362 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -304,6 +304,7 @@ distributed: delay: min: 1s # the first non-zero delay between re-tries max: 30s # the maximum delay between re-tries + disk: True ################### # Bokeh dashboard # diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index d20ea106242..c09d5f04363 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -25,6 +25,7 @@ from distributed.shuffle._disk import DiskShardsBuffer from distributed.shuffle._exceptions import ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter +from distributed.shuffle._memory import MemoryShardsBuffer from distributed.utils_comm import retry if TYPE_CHECKING: @@ -47,6 +48,17 @@ class ShuffleRun(Generic[_T_partition_id, _T_partition_type]): + id: ShuffleId + run_id: int + local_address: str + executor: ThreadPoolExecutor + rpc: Callable[[str], PooledRPCCall] + scheduler: PooledRPCCall + closed: bool + _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer + _comm_buffer: CommShardsBuffer + diagnostics: dict[str, float] + def __init__( self, id: ShuffleId, @@ -58,6 +70,7 @@ def __init__( scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, memory_limiter_comms: ResourceLimiter, + disk: bool, ): self.id = id self.run_id = run_id @@ -66,12 +79,14 @@ def __init__( self.rpc = rpc self.scheduler = scheduler self.closed = False - - self._disk_buffer = DiskShardsBuffer( - directory=directory, - read=self.read, - memory_limiter=memory_limiter_disk, - ) + if disk: + self._disk_buffer = DiskShardsBuffer( + directory=directory, + read=self.read, + memory_limiter=memory_limiter_disk, + ) + else: + self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize) self._comm_buffer = CommShardsBuffer( send=self.send, memory_limiter=memory_limiter_comms @@ -270,6 +285,10 @@ async def _get_output_partition( def read(self, path: Path) -> tuple[Any, int]: """Read shards from disk""" + @abc.abstractmethod + def deserialize(self, buffer: bytes) -> Any: + """Deserialize shards""" + def get_worker_plugin() -> ShuffleWorkerPlugin: from distributed import get_worker @@ -321,6 +340,7 @@ def id(self) -> ShuffleId: @dataclass(frozen=True) class ShuffleSpec(abc.ABC, Generic[_T_partition_id]): id: ShuffleId + disk: bool def create_new_run( self, diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index da7aa1ea352..87fea2cb99f 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -163,7 +163,7 @@ async def _process(self, id: str, shards: list[bytes]) -> None: for shard in shards: f.write(shard) - def read(self, id: int | str) -> Any: + def read(self, id: str) -> Any: """Read a complete file back into memory""" self.raise_on_exception() if not self._inputs_done: diff --git a/distributed/shuffle/_memory.py b/distributed/shuffle/_memory.py new file mode 100644 index 00000000000..262073b31cd --- /dev/null +++ b/distributed/shuffle/_memory.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from collections import defaultdict, deque +from typing import Any, Callable + +from dask.sizeof import sizeof + +from distributed.shuffle._buffer import ShardsBuffer +from distributed.shuffle._limiter import ResourceLimiter +from distributed.utils import log_errors + + +class MemoryShardsBuffer(ShardsBuffer): + _deserialize: Callable[[bytes], Any] + _shards: defaultdict[str, deque[bytes]] + + def __init__(self, deserialize: Callable[[bytes], Any]) -> None: + super().__init__( + memory_limiter=ResourceLimiter(None), + ) + self._deserialize = deserialize + self._shards = defaultdict(deque) + + async def _process(self, id: str, shards: list[bytes]) -> None: + # TODO: This can be greatly simplified, there's no need for + # background threads at all. + with log_errors(): + with self.time("write"): + self._shards[id].extend(shards) + + def read(self, id: str) -> Any: + self.raise_on_exception() + if not self._inputs_done: + raise RuntimeError("Tried to read from file before done.") + + with self.time("read"): + data = [] + size = 0 + shards = self._shards[id] + while shards: + shard = shards.pop() + data.append(self._deserialize(shard)) + size += sizeof(shards) + + if data: + self.bytes_read += size + return data + else: + raise KeyError(id) diff --git a/distributed/shuffle/_merge.py b/distributed/shuffle/_merge.py index b14fd8fce53..1a6bb4c2825 100644 --- a/distributed/shuffle/_merge.py +++ b/distributed/shuffle/_merge.py @@ -4,6 +4,7 @@ from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING, Any +import dask from dask.base import is_dask_collection, tokenize from dask.highlevelgraph import HighLevelGraph from dask.layers import Layer @@ -106,6 +107,7 @@ def hash_join_p2p( lhs = _calculate_partitions(lhs, left_on, npartitions) rhs = _calculate_partitions(rhs, right_on, npartitions) merge_name = "hash-join-" + tokenize(lhs, rhs, **merge_kwargs) + disk: bool = dask.config.get("distributed.p2p.disk") join_layer = HashJoinP2PLayer( name=merge_name, name_input_left=lhs._name, @@ -123,6 +125,7 @@ def hash_join_p2p( indicator=indicator, left_index=left_index, right_index=right_index, + disk=disk, ) graph = HighLevelGraph.from_collections( merge_name, join_layer, dependencies=[lhs, rhs] @@ -142,6 +145,7 @@ def merge_transfer( npartitions: int, meta: pd.DataFrame, parts_out: set[int], + disk: bool, ): return shuffle_transfer( input=input, @@ -151,6 +155,7 @@ def merge_transfer( column=_HASH_COLUMN_NAME, meta=meta, parts_out=parts_out, + disk=disk, ) @@ -227,6 +232,7 @@ def __init__( left_index: bool, right_index: bool, npartitions: int, + disk: bool, how: MergeHow = "inner", suffixes: Suffixes = ("_x", "_y"), indicator: bool = False, @@ -251,6 +257,7 @@ def __init__( self.n_partitions_right = n_partitions_right self.left_index = left_index self.right_index = right_index + self.disk = disk annotations = annotations or {} annotations.update({"shuffle": lambda key: key[-1]}) super().__init__(annotations=annotations) @@ -332,6 +339,7 @@ def _cull(self, parts_out: Sequence[str]): parts_out=parts_out, left_index=self.left_index, right_index=self.right_index, + disk=self.disk, annotations=self.annotations, n_partitions_left=self.n_partitions_left, n_partitions_right=self.n_partitions_right, @@ -381,6 +389,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]: self.npartitions, self.meta_input_left, self.parts_out, + self.disk, ) for i in range(self.n_partitions_right): transfer_keys_right.append((name_right, i)) @@ -392,6 +401,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]: self.npartitions, self.meta_input_right, self.parts_out, + self.disk, ) _barrier_key_left = barrier_key(ShuffleId(token_left)) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index f0efd4f66f3..8fb8172508d 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -143,12 +143,13 @@ def rechunk_transfer( input_chunk: NDIndex, new: ChunkedAxes, old: ChunkedAxes, + disk: bool, ) -> int: with handle_transfer_errors(id): return get_worker_plugin().add_partition( input, partition_id=input_chunk, - spec=ArrayRechunkSpec(id=id, new=new, old=old), + spec=ArrayRechunkSpec(id=id, new=new, old=old, disk=disk), ) @@ -174,6 +175,7 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array: token = tokenize(x, chunks) _barrier_key = barrier_key(ShuffleId(token)) name = f"rechunk-transfer-{token}" + disk: bool = dask.config.get("distributed.p2p.disk") transfer_keys = [] for index in np.ndindex(tuple(len(dim) for dim in x.chunks)): transfer_keys.append((name,) + index) @@ -184,6 +186,7 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array: index, chunks, x.chunks, + disk, ) dsk[_barrier_key] = (shuffle_barrier, token, transfer_keys) @@ -258,14 +261,15 @@ def split_axes(old: ChunkedAxes, new: ChunkedAxes) -> SplitAxes: return axes -def convert_chunk(shards: list[tuple[NDIndex, np.ndarray]]) -> np.ndarray: +def convert_chunk(shards: list[list[tuple[NDIndex, np.ndarray]]]) -> np.ndarray: import numpy as np from dask.array.core import concatenate3 indexed: dict[NDIndex, np.ndarray] = {} - for index, shard in shards: - indexed[index] = shard + for sublist in shards: + for index, shard in sublist: + indexed[index] = shard del shards subshape = [max(dim) + 1 for dim in zip(*indexed.keys())] @@ -333,6 +337,7 @@ def __init__( scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, memory_limiter_comms: ResourceLimiter, + disk: bool, ): super().__init__( id=id, @@ -344,6 +349,7 @@ def __init__( scheduler=scheduler, memory_limiter_comms=memory_limiter_comms, memory_limiter_disk=memory_limiter_disk, + disk=disk, ) self.old = old self.new = new @@ -429,13 +435,17 @@ def _(partition_id: NDIndex) -> np.ndarray: return await self.offload(_, partition_id) + def deserialize(self, buffer: bytes) -> Any: + result = pickle.loads(buffer) + return result + def read(self, path: Path) -> tuple[Any, int]: - shards: list[tuple[NDIndex, np.ndarray]] = [] + shards: list[list[tuple[NDIndex, np.ndarray]]] = [] with path.open(mode="rb") as f: size = f.seek(0, os.SEEK_END) f.seek(0) while f.tell() < size: - shards.extend(pickle.load(f)) + shards.append(pickle.load(f)) return shards, size def _get_assigned_worker(self, id: NDIndex) -> str: @@ -475,6 +485,7 @@ def create_run_on_worker( scheduler=plugin.worker.scheduler, memory_limiter_disk=plugin.memory_limiter_disk, memory_limiter_comms=plugin.memory_limiter_comms, + disk=self.disk, ) diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 1bdec302d2c..e9d76dadf05 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -12,6 +12,7 @@ import toolz +import dask from dask.base import tokenize from dask.highlevelgraph import HighLevelGraph from dask.layers import Layer @@ -23,6 +24,7 @@ check_dtype_support, check_minimal_arrow_version, convert_shards, + deserialize_table, list_of_buffers_to_table, read_from_disk, serialize_table, @@ -61,6 +63,7 @@ def shuffle_transfer( column: str, meta: pd.DataFrame, parts_out: set[int], + disk: bool, ) -> int: with handle_transfer_errors(id): return get_worker_plugin().add_partition( @@ -72,6 +75,7 @@ def shuffle_transfer( column=column, meta=meta, parts_out=parts_out, + disk=disk, ), ) @@ -119,6 +123,8 @@ def rearrange_by_column_p2p( ) name = f"shuffle_p2p-{token}" + disk: bool = dask.config.get("distributed.p2p.disk") + layer = P2PShuffleLayer( name, column, @@ -126,6 +132,7 @@ def rearrange_by_column_p2p( npartitions_input=df.npartitions, name_input=df._name, meta_input=meta, + disk=disk, ) return new_dd_object( HighLevelGraph.from_collections(name, layer, [df]), @@ -143,7 +150,9 @@ class P2PShuffleLayer(Layer): column: str npartitions: int npartitions_input: int + name_input: str meta_input: pd.DataFrame + disk: bool parts_out: set[int] def __init__( @@ -154,7 +163,8 @@ def __init__( npartitions_input: int, name_input: str, meta_input: pd.DataFrame, - parts_out: Iterable | None = None, + disk: bool, + parts_out: Iterable[int] | None = None, annotations: dict | None = None, ): check_minimal_arrow_version() @@ -163,6 +173,7 @@ def __init__( self.npartitions = npartitions self.name_input = name_input self.meta_input = meta_input + self.disk = disk if parts_out: self.parts_out = set(parts_out) else: @@ -212,6 +223,7 @@ def _cull(self, parts_out: Iterable[int]) -> P2PShuffleLayer: self.npartitions_input, self.name_input, self.meta_input, + self.disk, parts_out=parts_out, ) @@ -266,6 +278,7 @@ def _construct_graph(self) -> _T_LowLevelGraph: self.column, self.meta_input, self.parts_out, + self.disk, ) dsk[_barrier_key] = (shuffle_barrier, token, transfer_keys) @@ -417,6 +430,7 @@ def __init__( scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, memory_limiter_comms: ResourceLimiter, + disk: bool, ): import pandas as pd @@ -430,6 +444,7 @@ def __init__( scheduler=scheduler, memory_limiter_comms=memory_limiter_comms, memory_limiter_disk=memory_limiter_disk, + disk=disk, ) self.column = column self.meta = meta @@ -512,6 +527,9 @@ def _get_assigned_worker(self, id: int) -> str: def read(self, path: Path) -> tuple[pa.Table, int]: return read_from_disk(path) + def deserialize(self, buffer: bytes) -> Any: + return deserialize_table(buffer) + @dataclass(frozen=True) class DataFrameShuffleSpec(ShuffleSpec[int]): @@ -541,8 +559,11 @@ def create_run_on_worker( local_address=plugin.worker.address, rpc=plugin.worker.rpc, scheduler=plugin.worker.scheduler, - memory_limiter_disk=plugin.memory_limiter_disk, + memory_limiter_disk=plugin.memory_limiter_disk + if self.disk + else ResourceLimiter(None), memory_limiter_comms=plugin.memory_limiter_comms, + disk=self.disk, ) diff --git a/distributed/shuffle/tests/test_memory_buffer.py b/distributed/shuffle/tests/test_memory_buffer.py new file mode 100644 index 00000000000..8182597427c --- /dev/null +++ b/distributed/shuffle/tests/test_memory_buffer.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import pytest + +from distributed.shuffle._memory import MemoryShardsBuffer +from distributed.utils_test import gen_test + + +def deserialize_bytes(buffer: bytes) -> bytes: + return buffer + + +@gen_test() +async def test_basic(): + async with MemoryShardsBuffer(deserialize=deserialize_bytes) as mf: + await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) + await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) + + await mf.flush() + + x = mf.read("x") + y = mf.read("y") + + with pytest.raises(KeyError): + mf.read("z") + + assert x == [b"0" * 1000] * 2 + assert y == [b"1" * 500] * 2 + + +@gen_test() +async def test_read_before_flush(): + payload = {"1": b"foo"} + async with MemoryShardsBuffer(deserialize=deserialize_bytes) as mf: + with pytest.raises(RuntimeError): + mf.read("1") + + await mf.write(payload) + + with pytest.raises(RuntimeError): + mf.read("1") + + await mf.flush() + assert mf.read("1") == [b"foo"] + with pytest.raises(KeyError): + mf.read("2") + + +@pytest.mark.parametrize("count", [2, 100, 1000]) +@gen_test() +async def test_many(count): + async with MemoryShardsBuffer(deserialize=deserialize_bytes) as mf: + d = {str(i): str(i).encode() * 100 for i in range(count)} + + for _ in range(10): + await mf.write(d) + + await mf.flush() + + for i in d: + out = mf.read(str(i)) + assert out == [str(i).encode() * 100] * 10 diff --git a/distributed/shuffle/tests/test_merge.py b/distributed/shuffle/tests/test_merge.py index c2879224abf..fb2341f5618 100644 --- a/distributed/shuffle/tests/test_merge.py +++ b/distributed/shuffle/tests/test_merge.py @@ -12,6 +12,7 @@ dd = pytest.importorskip("dask.dataframe") import pandas as pd +import dask from dask.dataframe._compat import PANDAS_GE_200, tm from dask.dataframe.utils import assert_eq from dask.utils_test import hlg_layer_topological @@ -106,8 +107,9 @@ async def test_basic_merge(c, s, a, b, how, lose_annotations): @pytest.mark.parametrize("how", ["inner", "outer", "left", "right"]) +@pytest.mark.parametrize("disk", [True, False]) @gen_cluster(client=True) -async def test_merge(c, s, a, b, how, lose_annotations): +async def test_merge(c, s, a, b, how, disk, lose_annotations): await invoke_annotation_chaos(lose_annotations, c) A = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6], "y": [1, 1, 2, 2, 3, 4]}) a = dd.repartition(A, [0, 4, 5]) @@ -115,7 +117,10 @@ async def test_merge(c, s, a, b, how, lose_annotations): B = pd.DataFrame({"y": [1, 3, 4, 4, 5, 6], "z": [6, 5, 4, 3, 2, 1]}) b = dd.repartition(B, [0, 2, 5]) - joined = dd.merge(a, b, left_index=True, right_index=True, how=how, shuffle="p2p") + with dask.config.set({"distributed.p2p.disk": disk}): + joined = dd.merge( + a, b, left_index=True, right_index=True, how=how, shuffle="p2p" + ) res = await c.compute(joined) assert_eq( res, diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index d09e5e97ef9..3cc9cae3081 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -54,6 +54,7 @@ def new_shuffle( new, directory, loop, + disk, Shuffle=ArrayRechunkRun, ): s = Shuffle( @@ -69,6 +70,7 @@ def new_shuffle( scheduler=self, memory_limiter_disk=ResourceLimiter(10000000), memory_limiter_comms=ResourceLimiter(10000000), + disk=disk, ) self.shuffles[name] = s return s @@ -79,9 +81,10 @@ def new_shuffle( @pytest.mark.parametrize("n_workers", [1, 10]) @pytest.mark.parametrize("barrier_first_worker", [True, False]) +@pytest.mark.parametrize("disk", [True, False]) @gen_test() async def test_lowlevel_rechunk( - tmp_path, loop_in_thread, n_workers, barrier_first_worker + tmp_path, loop_in_thread, n_workers, barrier_first_worker, disk ): old = ((1, 2, 3, 4), (5,) * 6) new = ((5, 5), (12, 18)) @@ -113,6 +116,7 @@ async def test_lowlevel_rechunk( new=new, directory=tmp_path, loop=loop_in_thread, + disk=disk, ) ) random.seed(42) @@ -186,8 +190,9 @@ async def test_rechunk_configuration(c, s, *ws, config_value, keyword): assert np.all(await c.compute(x2) == a) +@pytest.mark.parametrize("disk", [True, False]) @gen_cluster(client=True) -async def test_rechunk_2d(c, s, *ws): +async def test_rechunk_2d(c, s, *ws, disk): """Try rechunking a random 2d matrix See Also @@ -197,13 +202,15 @@ async def test_rechunk_2d(c, s, *ws): a = np.random.default_rng().uniform(0, 1, 300).reshape((10, 30)) x = da.from_array(a, chunks=((1, 2, 3, 4), (5,) * 6)) new = ((5, 5), (15,) * 2) - x2 = rechunk(x, chunks=new, method="p2p") + with dask.config.set({"distributed.p2p.disk": disk}): + x2 = rechunk(x, chunks=new, method="p2p") assert x2.chunks == new assert np.all(await c.compute(x2) == a) +@pytest.mark.parametrize("disk", [True, False]) @gen_cluster(client=True) -async def test_rechunk_4d(c, s, *ws): +async def test_rechunk_4d(c, s, *ws, disk): """Try rechunking a random 4d matrix See Also @@ -219,7 +226,8 @@ async def test_rechunk_4d(c, s, *ws): (10,), (8, 2), ) # This has been altered to return >1 output partition - x2 = rechunk(x, chunks=new, method="p2p") + with dask.config.set({"distributed.p2p.disk": disk}): + x2 = rechunk(x, chunks=new, method="p2p") assert x2.chunks == new await c.compute(x2) assert np.all(await c.compute(x2) == a) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 1b4bbce2656..6a521c7a86c 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -183,8 +183,9 @@ def get_active_shuffle_runs(worker: Worker) -> dict[ShuffleId, ShuffleRun]: @pytest.mark.parametrize("npartitions", [None, 1, 20]) +@pytest.mark.parametrize("disk", [True, False]) @gen_cluster(client=True) -async def test_basic_integration(c, s, a, b, lose_annotations, npartitions): +async def test_basic_integration(c, s, a, b, lose_annotations, npartitions, disk): await invoke_annotation_chaos(lose_annotations, c) df = dask.datasets.timeseries( start="2000-01-01", @@ -192,7 +193,8 @@ async def test_basic_integration(c, s, a, b, lose_annotations, npartitions): dtypes={"x": float, "y": float}, freq="10 s", ) - shuffled = dd.shuffle.shuffle(df, "x", shuffle="p2p", npartitions=npartitions) + with dask.config.set({"distributed.p2p.disk": disk}): + shuffled = dd.shuffle.shuffle(df, "x", shuffle="p2p", npartitions=npartitions) if npartitions is None: assert shuffled.npartitions == df.npartitions else: @@ -1566,6 +1568,7 @@ def new_shuffle( worker_for_mapping, directory, loop, + disk, Shuffle=DataFrameShuffleRun, ): s = Shuffle( @@ -1581,6 +1584,7 @@ def new_shuffle( scheduler=self, memory_limiter_disk=ResourceLimiter(10000000), memory_limiter_comms=ResourceLimiter(10000000), + disk=disk, ) self.shuffles[name] = s return s @@ -1592,6 +1596,7 @@ def new_shuffle( @pytest.mark.parametrize("n_input_partitions", [1, 2, 10]) @pytest.mark.parametrize("npartitions", [1, 20]) @pytest.mark.parametrize("barrier_first_worker", [True, False]) +@pytest.mark.parametrize("disk", [True, False]) @gen_test() async def test_basic_lowlevel_shuffle( tmp_path, @@ -1600,6 +1605,7 @@ async def test_basic_lowlevel_shuffle( n_input_partitions, npartitions, barrier_first_worker, + disk, ): pa = pytest.importorskip("pyarrow") @@ -1631,6 +1637,7 @@ async def test_basic_lowlevel_shuffle( worker_for_mapping=worker_for_mapping, directory=tmp_path, loop=loop_in_thread, + disk=disk, ) ) random.seed(42) @@ -1707,6 +1714,7 @@ async def offload(self, func, *args): worker_for_mapping=worker_for_mapping, directory=tmp_path, loop=loop_in_thread, + disk=True, Shuffle=ErrorOffload, ) sB = local_shuffle_pool.new_shuffle( @@ -1715,6 +1723,7 @@ async def offload(self, func, *args): worker_for_mapping=worker_for_mapping, directory=tmp_path, loop=loop_in_thread, + disk=True, ) try: await sB.add_partition(dfs[0], 0) @@ -1761,6 +1770,7 @@ async def send(self, *args: Any, **kwargs: Any) -> None: worker_for_mapping=worker_for_mapping, directory=tmp_path, loop=loop_in_thread, + disk=True, Shuffle=ErrorSend, ) sB = local_shuffle_pool.new_shuffle( @@ -1769,6 +1779,7 @@ async def send(self, *args: Any, **kwargs: Any) -> None: worker_for_mapping=worker_for_mapping, directory=tmp_path, loop=loop_in_thread, + disk=True, ) try: await sA.add_partition(dfs[0], 0) @@ -1814,6 +1825,7 @@ async def receive(self, data: list[tuple[int, bytes]]) -> None: worker_for_mapping=worker_for_mapping, directory=tmp_path, loop=loop_in_thread, + disk=True, Shuffle=ErrorReceive, ) sB = local_shuffle_pool.new_shuffle( @@ -1822,6 +1834,7 @@ async def receive(self, data: list[tuple[int, bytes]]) -> None: worker_for_mapping=worker_for_mapping, directory=tmp_path, loop=loop_in_thread, + disk=True, ) try: await sB.add_partition(dfs[0], 0)