Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deduplicate scheduler requests in P2P #8899

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 26 additions & 21 deletions distributed/shuffle/_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
from collections import defaultdict
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, overload
Expand Down Expand Up @@ -39,6 +40,7 @@ class _ShuffleRunManager:
closed: bool
_active_runs: dict[ShuffleId, ShuffleRun]
_runs: set[ShuffleRun]
_refresh_locks: defaultdict[ShuffleId, asyncio.Lock]
#: Mapping of shuffle IDs to the largest stale run ID.
#: This is used to prevent race conditions between fetching shuffle run data
#: from the scheduler and failing a shuffle run.
Expand All @@ -51,6 +53,7 @@ def __init__(self, plugin: ShuffleWorkerPlugin) -> None:
self.closed = False
self._active_runs = {}
self._runs = set()
self._refresh_locks = defaultdict(asyncio.Lock)
self._stale_run_ids = {}
self._runs_cleanup_condition = asyncio.Condition()
self._plugin = plugin
Expand Down Expand Up @@ -117,20 +120,21 @@ async def get_with_run_id(self, shuffle_id: ShuffleId, run_id: int) -> ShuffleRu
ShuffleClosedError
If the run manager has been closed
"""
shuffle_run = self._active_runs.get(shuffle_id, None)
if shuffle_run is None or shuffle_run.run_id < run_id:
shuffle_run = await self._refresh(shuffle_id=shuffle_id)

if shuffle_run.run_id > run_id:
raise P2PConsistencyError(f"{run_id=} stale, got {shuffle_run}")
elif shuffle_run.run_id < run_id:
raise P2PConsistencyError(f"{run_id=} invalid, got {shuffle_run}")

if self.closed:
raise ShuffleClosedError(f"{self} has already been closed")
if shuffle_run._exception:
raise shuffle_run._exception
return shuffle_run
async with self._refresh_locks[shuffle_id]:
shuffle_run = self._active_runs.get(shuffle_id, None)
if shuffle_run is None or shuffle_run.run_id < run_id:
shuffle_run = await self._refresh(shuffle_id=shuffle_id)

if shuffle_run.run_id > run_id:
raise P2PConsistencyError(f"{run_id=} stale, got {shuffle_run}")
elif shuffle_run.run_id < run_id:
raise P2PConsistencyError(f"{run_id=} invalid, got {shuffle_run}")

if self.closed:
raise ShuffleClosedError(f"{self} has already been closed")
if shuffle_run._exception:
raise shuffle_run._exception
return shuffle_run

async def get_or_create(self, spec: ShuffleSpec, key: Key) -> ShuffleRun:
"""Get or create a shuffle matching the ID and data spec.
Expand All @@ -144,13 +148,14 @@ async def get_or_create(self, spec: ShuffleSpec, key: Key) -> ShuffleRun:
key:
Task key triggering the function
"""
shuffle_run = self._active_runs.get(spec.id, None)
if shuffle_run is None:
shuffle_run = await self._refresh(
shuffle_id=spec.id,
spec=spec,
key=key,
)
async with self._refresh_locks[spec.id]:
shuffle_run = self._active_runs.get(spec.id, None)
if shuffle_run is None:
shuffle_run = await self._refresh(
shuffle_id=spec.id,
spec=spec,
key=key,
)

if self.closed:
raise ShuffleClosedError(f"{self} has already been closed")
Expand Down
95 changes: 83 additions & 12 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2704,18 +2704,6 @@ async def test_unpack_gets_rescheduled_from_non_participating_worker(c, s, a):
dd.assert_eq(result, expected)


class BlockedBarrierShuffleSchedulerPlugin(ShuffleSchedulerPlugin):
def __init__(self, scheduler: Scheduler):
super().__init__(scheduler)
self.in_barrier = asyncio.Event()
self.block_barrier = asyncio.Event()

async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
self.in_barrier.set()
await self.block_barrier.wait()
return await super().barrier(id, run_id, consistent)


class FlakyConnectionPool(ConnectionPool):
def __init__(self, *args, failing_connects=0, **kwargs):
self.attempts = 0
Expand Down Expand Up @@ -2955,3 +2943,86 @@ async def test_dont_downscale_participating_workers(c, s, a, b):

workers_to_close = s.workers_to_close(n=2)
assert len(workers_to_close) == 2


class RequestCountingSchedulerPlugin(ShuffleSchedulerPlugin):
def __init__(self, scheduler):
super().__init__(scheduler)
self.counts = defaultdict(int)

def get(self, *args, **kwargs):
self.counts["get"] += 1
return super().get(*args, **kwargs)

def get_or_create(self, *args, **kwargs):
self.counts["get_or_create"] += 1
return super().get_or_create(*args, **kwargs)


class PostFetchBlockingManager(_ShuffleRunManager):
def __init__(self, plugin):
super().__init__(plugin)
self.in_fetch = asyncio.Event()
self.block_fetch = asyncio.Event()

async def _fetch(self, *args, **kwargs):
result = await super()._fetch(*args, **kwargs)
self.in_fetch.set()
await self.block_fetch.wait()
return result


@mock.patch(
"distributed.shuffle.ShuffleSchedulerPlugin",
RequestCountingSchedulerPlugin,
)
@mock.patch(
"distributed.shuffle._worker_plugin._ShuffleRunManager",
PostFetchBlockingManager,
)
@gen_cluster(
client=True,
nthreads=[("", 2)] * 2,
config={
"distributed.scheduler.allowed-failures": 0,
"distributed.p2p.comm.message-size-limit": "10 B",
},
)
async def test_workers_do_not_spam_get_requests(c, s, a, b):
df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-02-01",
dtypes={"x": float, "y": float},
freq="10 s",
)
s.remove_plugin("shuffle")
shuffle_extS = RequestCountingSchedulerPlugin(s)
shuffle_extA = a.plugins["shuffle"]
shuffle_extB = b.plugins["shuffle"]

with dask.config.set({"dataframe.shuffle.method": "p2p"}):
out = df.shuffle("x", npartitions=100)
out = c.compute(out.x.size)

shuffle_id = await wait_until_new_shuffle_is_initialized(s)
key = barrier_key(shuffle_id)
await shuffle_extA.shuffle_runs.in_fetch.wait()
await shuffle_extB.shuffle_runs.in_fetch.wait()

shuffle_extA.shuffle_runs.block_fetch.set()

barrier_task = s.tasks[key]
while any(
ts.state not in ("processing", "memory") for ts in barrier_task.dependencies
):
await asyncio.sleep(0.1)
shuffle_extB.shuffle_runs.block_fetch.set()
await out

assert sum(shuffle_extS.counts.values()) == 2

del out

await assert_worker_cleanup(a)
await assert_worker_cleanup(b)
await assert_scheduler_cleanup(s)
Loading