Skip to content

Commit

Permalink
Deduplicate requests to scheduler in P2P (dask#8899)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Oct 17, 2024
1 parent 42e34e3 commit 48509b3
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 33 deletions.
47 changes: 26 additions & 21 deletions distributed/shuffle/_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
from collections import defaultdict
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, overload
Expand Down Expand Up @@ -39,6 +40,7 @@ class _ShuffleRunManager:
closed: bool
_active_runs: dict[ShuffleId, ShuffleRun]
_runs: set[ShuffleRun]
_refresh_locks: defaultdict[ShuffleId, asyncio.Lock]
#: Mapping of shuffle IDs to the largest stale run ID.
#: This is used to prevent race conditions between fetching shuffle run data
#: from the scheduler and failing a shuffle run.
Expand All @@ -51,6 +53,7 @@ def __init__(self, plugin: ShuffleWorkerPlugin) -> None:
self.closed = False
self._active_runs = {}
self._runs = set()
self._refresh_locks = defaultdict(asyncio.Lock)
self._stale_run_ids = {}
self._runs_cleanup_condition = asyncio.Condition()
self._plugin = plugin
Expand Down Expand Up @@ -117,20 +120,21 @@ async def get_with_run_id(self, shuffle_id: ShuffleId, run_id: int) -> ShuffleRu
ShuffleClosedError
If the run manager has been closed
"""
shuffle_run = self._active_runs.get(shuffle_id, None)
if shuffle_run is None or shuffle_run.run_id < run_id:
shuffle_run = await self._refresh(shuffle_id=shuffle_id)

if shuffle_run.run_id > run_id:
raise P2PConsistencyError(f"{run_id=} stale, got {shuffle_run}")
elif shuffle_run.run_id < run_id:
raise P2PConsistencyError(f"{run_id=} invalid, got {shuffle_run}")

if self.closed:
raise ShuffleClosedError(f"{self} has already been closed")
if shuffle_run._exception:
raise shuffle_run._exception
return shuffle_run
async with self._refresh_locks[shuffle_id]:
shuffle_run = self._active_runs.get(shuffle_id, None)
if shuffle_run is None or shuffle_run.run_id < run_id:
shuffle_run = await self._refresh(shuffle_id=shuffle_id)

if shuffle_run.run_id > run_id:
raise P2PConsistencyError(f"{run_id=} stale, got {shuffle_run}")
elif shuffle_run.run_id < run_id:
raise P2PConsistencyError(f"{run_id=} invalid, got {shuffle_run}")

if self.closed:
raise ShuffleClosedError(f"{self} has already been closed")
if shuffle_run._exception:
raise shuffle_run._exception
return shuffle_run

async def get_or_create(self, spec: ShuffleSpec, key: Key) -> ShuffleRun:
"""Get or create a shuffle matching the ID and data spec.
Expand All @@ -144,13 +148,14 @@ async def get_or_create(self, spec: ShuffleSpec, key: Key) -> ShuffleRun:
key:
Task key triggering the function
"""
shuffle_run = self._active_runs.get(spec.id, None)
if shuffle_run is None:
shuffle_run = await self._refresh(
shuffle_id=spec.id,
spec=spec,
key=key,
)
async with self._refresh_locks[spec.id]:
shuffle_run = self._active_runs.get(spec.id, None)
if shuffle_run is None:
shuffle_run = await self._refresh(
shuffle_id=spec.id,
spec=spec,
key=key,
)

if self.closed:
raise ShuffleClosedError(f"{self} has already been closed")
Expand Down
95 changes: 83 additions & 12 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2704,18 +2704,6 @@ async def test_unpack_gets_rescheduled_from_non_participating_worker(c, s, a):
dd.assert_eq(result, expected)


class BlockedBarrierShuffleSchedulerPlugin(ShuffleSchedulerPlugin):
    """Scheduler plugin whose ``barrier`` call can be held back by a test.

    ``in_barrier`` is set as soon as ``barrier`` is entered; the call then
    waits on ``block_barrier`` before delegating to the parent
    implementation.
    """

    def __init__(self, scheduler: Scheduler):
        super().__init__(scheduler)
        # Gate the test sets to let a parked ``barrier`` call continue.
        self.block_barrier = asyncio.Event()
        # Signal telling the test that ``barrier`` has been entered.
        self.in_barrier = asyncio.Event()

    async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
        """Signal entry, wait for the gate, then run the parent ``barrier``."""
        self.in_barrier.set()
        await self.block_barrier.wait()
        outcome = await super().barrier(id, run_id, consistent)
        return outcome


class FlakyConnectionPool(ConnectionPool):
def __init__(self, *args, failing_connects=0, **kwargs):
self.attempts = 0
Expand Down Expand Up @@ -2955,3 +2943,86 @@ async def test_dont_downscale_participating_workers(c, s, a, b):

workers_to_close = s.workers_to_close(n=2)
assert len(workers_to_close) == 2


class RequestCountingSchedulerPlugin(ShuffleSchedulerPlugin):
    """Scheduler plugin that tallies calls to ``get`` and ``get_or_create``.

    The tallies live in ``counts``, keyed by the method name; every call is
    forwarded unchanged to the parent implementation after being counted.
    """

    def __init__(self, scheduler):
        super().__init__(scheduler)
        # defaultdict(int) so an uncounted method reads as zero.
        self.counts = defaultdict(int)

    def get(self, *args, **kwargs):
        """Count the call under ``"get"`` and delegate to the parent."""
        self.counts["get"] = self.counts["get"] + 1
        response = super().get(*args, **kwargs)
        return response

    def get_or_create(self, *args, **kwargs):
        """Count the call under ``"get_or_create"`` and delegate to the parent."""
        self.counts["get_or_create"] = self.counts["get_or_create"] + 1
        response = super().get_or_create(*args, **kwargs)
        return response


class PostFetchBlockingManager(_ShuffleRunManager):
    """Run manager that pauses right after ``_fetch`` has produced a result.

    ``in_fetch`` is set once the parent ``_fetch`` has returned; the result is
    only handed back to the caller after ``block_fetch`` has been set.
    """

    def __init__(self, plugin):
        super().__init__(plugin)
        # Gate the test sets to release a parked ``_fetch`` call.
        self.block_fetch = asyncio.Event()
        # Signal telling the test that a fetch has completed and is parked.
        self.in_fetch = asyncio.Event()

    async def _fetch(self, *args, **kwargs):
        """Run the parent ``_fetch``, signal completion, then wait for release."""
        fetched = await super()._fetch(*args, **kwargs)
        self.in_fetch.set()
        await self.block_fetch.wait()
        return fetched


@mock.patch(
    "distributed.shuffle.ShuffleSchedulerPlugin",
    RequestCountingSchedulerPlugin,
)
@mock.patch(
    "distributed.shuffle._worker_plugin._ShuffleRunManager",
    PostFetchBlockingManager,
)
@gen_cluster(
    client=True,
    nthreads=[("", 2)] * 2,
    config={
        "distributed.scheduler.allowed-failures": 0,
        "distributed.p2p.comm.message-size-limit": "10 B",
    },
)
async def test_workers_do_not_spam_get_requests(c, s, a, b):
    """Check that a P2P shuffle on two workers causes exactly two requests.

    The scheduler-side plugin counts ``get`` and ``get_or_create`` calls; the
    sum of both counters must be 2 once the computation has finished.
    """
    df = dask.datasets.timeseries(
        start="2000-01-01",
        end="2000-02-01",
        dtypes={"x": float, "y": float},
        freq="10 s",
    )
    # Swap the registered scheduler plugin for a fresh counting instance.
    s.remove_plugin("shuffle")
    scheduler_plugin = RequestCountingSchedulerPlugin(s)
    plugin_a = a.plugins["shuffle"]
    plugin_b = b.plugins["shuffle"]

    with dask.config.set({"dataframe.shuffle.method": "p2p"}):
        result = df.shuffle("x", npartitions=100)
    result = c.compute(result.x.size)

    shuffle_id = await wait_until_new_shuffle_is_initialized(s)
    key = barrier_key(shuffle_id)

    # Both workers must have fetched and be waiting on their gate.
    # NOTE(review): relies on ``shuffle_runs`` being the patched
    # PostFetchBlockingManager -- confirm against the worker plugin.
    await plugin_a.shuffle_runs.in_fetch.wait()
    await plugin_b.shuffle_runs.in_fetch.wait()

    # Release worker A first while worker B stays parked after its fetch.
    plugin_a.shuffle_runs.block_fetch.set()

    barrier_task = s.tasks[key]
    settled_states = ("processing", "memory")
    # Poll until every dependency of the barrier is processing or in memory.
    while not all(ts.state in settled_states for ts in barrier_task.dependencies):
        await asyncio.sleep(0.1)

    plugin_b.shuffle_runs.block_fetch.set()
    await result

    total_requests = sum(scheduler_plugin.counts.values())
    assert total_requests == 2

    # Drop the future before running the cleanup assertions.
    del result

    await assert_worker_cleanup(a)
    await assert_worker_cleanup(b)
    await assert_scheduler_cleanup(s)

0 comments on commit 48509b3

Please sign in to comment.