Fix test_restarting_does_not_deadlock (#8849)
fjetter authored Oct 29, 2024
1 parent 09ed8af commit 2953090
Showing 1 changed file with 69 additions and 42 deletions.
distributed/shuffle/tests/test_shuffle.py: 69 additions & 42 deletions (111 changes)
@@ -326,23 +326,42 @@ async def test_bad_disk(c, s, a, b):
await assert_scheduler_cleanup(s)


async def wait_until_worker_has_tasks(
prefix: str, worker: str, count: int, scheduler: Scheduler, interval: float = 0.01
) -> None:
ws = scheduler.workers[worker]
while (
len(
[
key
for key, ts in scheduler.tasks.items()
if prefix in key_split(key)
and ts.state == "memory"
and {ws} == ts.who_has
]
)
< count
):
await asyncio.sleep(interval)
from distributed.diagnostics.plugin import SchedulerPlugin


class ObserveTasksPlugin(SchedulerPlugin):
def __init__(self, prefixes, count, worker):
self.prefixes = prefixes
self.count = count
self.worker = worker
self.counter = defaultdict(int)
self.event = asyncio.Event()

async def start(self, scheduler):
self.scheduler = scheduler

def transition(self, key, start, finish, *args, **kwargs):
if (
finish == "processing"
and key_split(key) in self.prefixes
and self.scheduler.tasks[key].processing_on
and self.scheduler.tasks[key].processing_on.address == self.worker
):
self.counter[key_split(key)] += 1
if self.counter[key_split(key)] == self.count:
self.event.set()
return key, start, finish


@contextlib.asynccontextmanager
async def wait_until_worker_has_tasks(prefix, worker, count, scheduler):
plugin = ObserveTasksPlugin([prefix], count, worker)
scheduler.add_plugin(plugin, name="observe-tasks")
await plugin.start(scheduler)
try:
yield plugin.event
finally:
scheduler.remove_plugin("observe-tasks")


async def wait_for_tasks_in_state(
@@ -562,8 +581,12 @@ async def test_get_or_create_from_dangling_transfer(c, s, a, b):
@pytest.mark.slow
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_crashed_worker_during_transfer(c, s, a):
async with Nanny(s.address, nthreads=1) as n:
killed_worker_address = n.worker_address
async with (
Nanny(s.address, nthreads=1) as n,
wait_until_worker_has_tasks(
"shuffle-transfer", n.worker_address, 1, s
) as event,
):
df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-03-01",
@@ -573,9 +596,7 @@ async def test_crashed_worker_during_transfer(c, s, a):
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
shuffled = df.shuffle("x")
fut = c.compute([shuffled, df], sync=True)
await wait_until_worker_has_tasks(
"shuffle-transfer", killed_worker_address, 1, s
)
await event.wait()
await n.process.process.kill()

result, expected = await fut
@@ -605,20 +626,16 @@ async def test_restarting_does_not_deadlock(c, s):
)
df = await c.persist(df)
expected = await c.compute(df)

async with Nanny(s.address) as b:
async with Worker(s.address) as b:
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
out = df.shuffle("x")
assert not s.workers[b.worker_address].has_what
result = c.compute(out)
await wait_until_worker_has_tasks(
"shuffle-transfer", b.worker_address, 1, s
)
while not s.extensions["shuffle"].active_shuffles:
await asyncio.sleep(0)
a.status = Status.paused
await async_poll_for(lambda: len(s.running) == 1, timeout=5)
b.close_gracefully()
await b.process.process.kill()

b.batched_stream.close()
await async_poll_for(lambda: not s.running, timeout=5)

a.status = Status.running
@@ -672,8 +689,12 @@ def mock_mock_get_worker_for_range_sharding(
"distributed.shuffle._shuffle._get_worker_for_range_sharding",
mock_mock_get_worker_for_range_sharding,
):
async with Nanny(s.address, nthreads=1) as n:
killed_worker_address = n.worker_address
async with (
Nanny(s.address, nthreads=1) as n,
wait_until_worker_has_tasks(
"shuffle-transfer", n.worker_address, 1, s
) as event,
):
df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-03-01",
@@ -683,9 +704,7 @@
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
shuffled = df.shuffle("x")
fut = c.compute([shuffled, df], sync=True)
await wait_until_worker_has_tasks(
"shuffle-transfer", n.worker_address, 1, s
)
await event.wait()
await n.process.process.kill()

result, expected = await fut
@@ -1033,8 +1052,10 @@ async def test_restarting_during_unpack_raises_killed_worker(c, s, a, b):
@pytest.mark.slow
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_crashed_worker_during_unpack(c, s, a):
async with Nanny(s.address, nthreads=2) as n:
killed_worker_address = n.worker_address
async with (
Nanny(s.address, nthreads=2) as n,
wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s) as event,
):
df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-03-01",
@@ -1046,7 +1067,7 @@ async def test_crashed_worker_during_unpack(c, s, a):
shuffled = df.shuffle("x")
result = c.compute(shuffled)

await wait_until_worker_has_tasks(UNPACK_PREFIX, killed_worker_address, 1, s)
await event.wait()
await n.process.process.kill()

result = await result
@@ -1486,7 +1507,10 @@ def block(df, in_event, block_event):
block_event.wait()
return df

async with Nanny(s.address, nthreads=1) as n:
async with (
Nanny(s.address, nthreads=1) as n,
wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s) as event,
):
df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-03-01",
@@ -1507,7 +1531,7 @@ def block(df, in_event, block_event):
allow_other_workers=True,
)

await wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s)
await event.wait()
await in_event.wait()
await n.process.process.kill()
await block_event.set()
@@ -1524,7 +1548,10 @@ def block(df, in_event, block_event):

@gen_cluster(client=True, nthreads=[("", 1)])
async def test_crashed_worker_after_shuffle_persisted(c, s, a):
async with Nanny(s.address, nthreads=1) as n:
async with (
Nanny(s.address, nthreads=1) as n,
wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s) as event,
):
df = df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-01-10",
@@ -1536,7 +1563,7 @@ async def test_crashed_worker_after_shuffle_persisted(c, s, a):
out = df.shuffle("x")
out = out.persist()

await wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s)
await event.wait()
await out

await n.process.process.kill()
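
For context, a minimal sketch (not part of the commit) of the pattern this diff introduces: instead of polling scheduler state in a sleep loop, the tests now install a SchedulerPlugin whose transition hook sets an asyncio.Event once the watched worker starts processing `count` tasks with a given prefix, and the test awaits that event deterministically before killing the worker. The test name below is hypothetical; `wait_until_worker_has_tasks` is the context manager added in this diff, and `c`, `s`, `a` come from the usual gen_cluster fixtures.

import dask
from distributed import Nanny
from distributed.utils_test import gen_cluster

# Sketch only, condensed from the tests in this diff; not the committed code verbatim.
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_crash_mid_transfer_sketch(c, s, a):
    async with (
        Nanny(s.address, nthreads=1) as n,
        # Yields an asyncio.Event set by ObserveTasksPlugin.transition() once
        # n.worker_address has started one "shuffle-transfer" task.
        wait_until_worker_has_tasks("shuffle-transfer", n.worker_address, 1, s) as event,
    ):
        df = dask.datasets.timeseries(start="2000-01-01", end="2000-03-01")
        with dask.config.set({"dataframe.shuffle.method": "p2p"}):
            fut = c.compute(df.shuffle("x"))
        await event.wait()              # deterministic: no sleep/poll loop
        await n.process.process.kill()  # crash the worker mid-transfer
        await fut                       # the shuffle recovers on the surviving worker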
