From 29530909fecedeead25b98198ed3ab28d185f6cc Mon Sep 17 00:00:00 2001
From: Florian Jetter
Date: Tue, 29 Oct 2024 10:54:46 +0100
Subject: [PATCH] Fix test_restarting_does_not_deadlock (#8849)

---
 distributed/shuffle/tests/test_shuffle.py | 111 ++++++++++++++--------
 1 file changed, 69 insertions(+), 42 deletions(-)

diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index bf2ab5bc09..17d55c8d4e 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -326,23 +326,42 @@ async def test_bad_disk(c, s, a, b):
     await assert_scheduler_cleanup(s)
 
 
-async def wait_until_worker_has_tasks(
-    prefix: str, worker: str, count: int, scheduler: Scheduler, interval: float = 0.01
-) -> None:
-    ws = scheduler.workers[worker]
-    while (
-        len(
-            [
-                key
-                for key, ts in scheduler.tasks.items()
-                if prefix in key_split(key)
-                and ts.state == "memory"
-                and {ws} == ts.who_has
-            ]
-        )
-        < count
-    ):
-        await asyncio.sleep(interval)
+from distributed.diagnostics.plugin import SchedulerPlugin
+
+
+class ObserveTasksPlugin(SchedulerPlugin):
+    def __init__(self, prefixes, count, worker):
+        self.prefixes = prefixes
+        self.count = count
+        self.worker = worker
+        self.counter = defaultdict(int)
+        self.event = asyncio.Event()
+
+    async def start(self, scheduler):
+        self.scheduler = scheduler
+
+    def transition(self, key, start, finish, *args, **kwargs):
+        if (
+            finish == "processing"
+            and key_split(key) in self.prefixes
+            and self.scheduler.tasks[key].processing_on
+            and self.scheduler.tasks[key].processing_on.address == self.worker
+        ):
+            self.counter[key_split(key)] += 1
+            if self.counter[key_split(key)] == self.count:
+                self.event.set()
+        return key, start, finish
+
+
+@contextlib.asynccontextmanager
+async def wait_until_worker_has_tasks(prefix, worker, count, scheduler):
+    plugin = ObserveTasksPlugin([prefix], count, worker)
+    scheduler.add_plugin(plugin, name="observe-tasks")
+    await plugin.start(scheduler)
+    try:
+        yield plugin.event
+    finally:
+        scheduler.remove_plugin("observe-tasks")
 
 
 async def wait_for_tasks_in_state(
@@ -562,8 +581,12 @@ async def test_get_or_create_from_dangling_transfer(c, s, a, b):
 @pytest.mark.slow
 @gen_cluster(client=True, nthreads=[("", 1)])
 async def test_crashed_worker_during_transfer(c, s, a):
-    async with Nanny(s.address, nthreads=1) as n:
-        killed_worker_address = n.worker_address
+    async with (
+        Nanny(s.address, nthreads=1) as n,
+        wait_until_worker_has_tasks(
+            "shuffle-transfer", n.worker_address, 1, s
+        ) as event,
+    ):
         df = dask.datasets.timeseries(
             start="2000-01-01",
             end="2000-03-01",
@@ -573,9 +596,7 @@ async def test_crashed_worker_during_transfer(c, s, a):
         with dask.config.set({"dataframe.shuffle.method": "p2p"}):
             shuffled = df.shuffle("x")
         fut = c.compute([shuffled, df], sync=True)
-        await wait_until_worker_has_tasks(
-            "shuffle-transfer", killed_worker_address, 1, s
-        )
+        await event.wait()
         await n.process.process.kill()
 
         result, expected = await fut
@@ -605,20 +626,16 @@ async def test_restarting_does_not_deadlock(c, s):
     )
     df = await c.persist(df)
     expected = await c.compute(df)
-
-    async with Nanny(s.address) as b:
+    async with Worker(s.address) as b:
         with dask.config.set({"dataframe.shuffle.method": "p2p"}):
             out = df.shuffle("x")
         assert not s.workers[b.worker_address].has_what
         result = c.compute(out)
-        await wait_until_worker_has_tasks(
-            "shuffle-transfer", b.worker_address, 1, s
-        )
+        while not s.extensions["shuffle"].active_shuffles:
+            await asyncio.sleep(0)
         a.status = Status.paused
         await async_poll_for(lambda: len(s.running) == 1, timeout=5)
-        b.close_gracefully()
-        await b.process.process.kill()
-
+        b.batched_stream.close()
         await async_poll_for(lambda: not s.running, timeout=5)
         a.status = Status.running
 
@@ -672,8 +689,12 @@ def mock_mock_get_worker_for_range_sharding(
         "distributed.shuffle._shuffle._get_worker_for_range_sharding",
         mock_mock_get_worker_for_range_sharding,
     ):
-        async with Nanny(s.address, nthreads=1) as n:
-            killed_worker_address = n.worker_address
+        async with (
+            Nanny(s.address, nthreads=1) as n,
+            wait_until_worker_has_tasks(
+                "shuffle-transfer", n.worker_address, 1, s
+            ) as event,
+        ):
             df = dask.datasets.timeseries(
                 start="2000-01-01",
                 end="2000-03-01",
@@ -683,9 +704,7 @@ def mock_mock_get_worker_for_range_sharding(
             with dask.config.set({"dataframe.shuffle.method": "p2p"}):
                 shuffled = df.shuffle("x")
             fut = c.compute([shuffled, df], sync=True)
-            await wait_until_worker_has_tasks(
-                "shuffle-transfer", n.worker_address, 1, s
-            )
+            await event.wait()
             await n.process.process.kill()
 
             result, expected = await fut
@@ -1033,8 +1052,10 @@ async def test_restarting_during_unpack_raises_killed_worker(c, s, a, b):
 @pytest.mark.slow
 @gen_cluster(client=True, nthreads=[("", 1)])
 async def test_crashed_worker_during_unpack(c, s, a):
-    async with Nanny(s.address, nthreads=2) as n:
-        killed_worker_address = n.worker_address
+    async with (
+        Nanny(s.address, nthreads=2) as n,
+        wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s) as event,
+    ):
         df = dask.datasets.timeseries(
             start="2000-01-01",
             end="2000-03-01",
@@ -1046,7 +1067,7 @@ async def test_crashed_worker_during_unpack(c, s, a):
             shuffled = df.shuffle("x")
         result = c.compute(shuffled)
 
-        await wait_until_worker_has_tasks(UNPACK_PREFIX, killed_worker_address, 1, s)
+        await event.wait()
         await n.process.process.kill()
 
         result = await result
@@ -1486,7 +1507,10 @@ def block(df, in_event, block_event):
         block_event.wait()
         return df
 
-    async with Nanny(s.address, nthreads=1) as n:
+    async with (
+        Nanny(s.address, nthreads=1) as n,
+        wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s) as event,
+    ):
         df = dask.datasets.timeseries(
             start="2000-01-01",
             end="2000-03-01",
@@ -1507,7 +1531,7 @@ def block(df, in_event, block_event):
             allow_other_workers=True,
         )
 
-        await wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s)
+        await event.wait()
         await in_event.wait()
         await n.process.process.kill()
         await block_event.set()
@@ -1524,7 +1548,10 @@ def block(df, in_event, block_event):
 
 @gen_cluster(client=True, nthreads=[("", 1)])
 async def test_crashed_worker_after_shuffle_persisted(c, s, a):
-    async with Nanny(s.address, nthreads=1) as n:
+    async with (
+        Nanny(s.address, nthreads=1) as n,
+        wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s) as event,
+    ):
         df = df = dask.datasets.timeseries(
             start="2000-01-01",
             end="2000-01-10",
@@ -1536,7 +1563,7 @@ async def test_crashed_worker_after_shuffle_persisted(c, s, a):
             out = df.shuffle("x")
         out = out.persist()
 
-        await wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s)
+        await event.wait()
         await out
        await n.process.process.kill()
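
For context, the core idea of the patch is to replace interval polling of scheduler state with an event-driven wait: the old helper slept and re-scanned scheduler.tasks every 10 ms, so it could miss short-lived task states, while a SchedulerPlugin's transition hook is invoked synchronously for every transition and cannot miss one. Below is a minimal, self-contained sketch of that pattern, separate from the diff above; the plugin and test names here are illustrative, not part of the patch.

import asyncio

from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.utils_test import gen_cluster


class FirstProcessingSeen(SchedulerPlugin):
    """Fire an event when the first task reaches the ``processing`` state."""

    def __init__(self):
        self.event = asyncio.Event()

    def transition(self, key, start, finish, *args, **kwargs):
        # Called by the scheduler for every task transition, so no state
        # change can slip through between polls.
        if finish == "processing":
            self.event.set()


@gen_cluster(client=True)
async def test_sketch(c, s, a, b):
    plugin = FirstProcessingSeen()
    s.add_plugin(plugin, name="first-processing")
    try:
        fut = c.submit(sum, [1, 2, 3])
        await plugin.event.wait()  # event-driven; no sleep/poll loop
        assert await fut == 6
    finally:
        s.remove_plugin("first-processing")

The patch's ObserveTasksPlugin applies the same idea, additionally filtering transitions by key prefix and by the worker the task is processing on, and wrapping plugin registration and removal in an async context manager so tests cannot leak the plugin.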