diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 43cdbe5508..c443caa8e9 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -1965,6 +1965,7 @@ def _transition(
                 )
 
                 v = a_recs.get(key, finish)
+                # The inner rec has higher priority? Is that always desired?
                 func = self._TRANSITIONS_TABLE["released", v]
                 b_recs, b_cmsgs, b_wmsgs = func(self, key, stimulus_id)
 
@@ -2083,7 +2084,11 @@ def _transition_released_waiting(self, key: Key, stimulus_id: str) -> RecsMsgs:
             assert not ts.who_has
             assert not ts.processing_on
             for dts in ts.dependencies:
-                assert dts.state not in {"forgotten", "erred"}
+                assert dts.state not in {"forgotten", "erred"}, (
+                    str(ts),
+                    str(dts),
+                    self.transition_log,
+                )
 
         if ts.has_lost_dependencies:
             return {key: "forgotten"}, {}, {}
@@ -2481,7 +2486,9 @@ def _transition_memory_released(
             recommendations[key] = "forgotten"
         elif ts.has_lost_dependencies:
             recommendations[key] = "forgotten"
-        elif ts.who_wants or ts.waiters:
+        elif (ts.who_wants or ts.waiters) and not any(
+            dts.state == "erred" for dts in ts.dependencies
+        ):
             recommendations[key] = "waiting"
 
         for dts in ts.waiters or ():
@@ -2506,14 +2513,13 @@ def _transition_released_erred(self, key: Key, stimulus_id: str) -> RecsMsgs:
             assert ts.exception_blame
             assert not ts.who_has
             assert not ts.waiting_on
-            assert not ts.waiters
 
         failing_ts = ts.exception_blame
         assert failing_ts
 
         for dts in ts.dependents:
-            dts.exception_blame = failing_ts
             if not dts.who_has:
+                dts.exception_blame = failing_ts
                 recommendations[dts.key] = "erred"
 
         report_msg = {
@@ -2548,6 +2554,9 @@ def _transition_erred_released(self, key: Key, stimulus_id: str) -> RecsMsgs:
 
         for dts in ts.dependents:
             if dts.state == "erred":
+                # Does this make sense?
+                # This goes via released
+                # dts -> released -> waiting
                 recommendations[dts.key] = "waiting"
 
         w_msg = {
@@ -2622,8 +2631,8 @@ def _transition_processing_erred(
         self,
         key: Key,
         stimulus_id: str,
-        *,
         worker: str,
+        *,
         cause: Key | None = None,
         exception: Serialized | None = None,
         traceback: Serialized | None = None,
@@ -2699,7 +2708,7 @@ def _transition_processing_erred(
             )
         )
 
-        for dts in ts.dependents:
+        for dts in ts.waiters or set():
             dts.exception_blame = failing_ts
             recommendations[dts.key] = "erred"
 
@@ -5040,6 +5049,19 @@ def stimulus_task_finished(
                     "stimulus_id": stimulus_id,
                 }
             ]
+        elif ts.state == "erred":
+            logger.debug(
+                "Received already erred task, worker: %s, key: %s",
+                worker,
+                key,
+            )
+            worker_msgs[worker] = [
+                {
+                    "op": "free-keys",
+                    "keys": [key],
+                    "stimulus_id": stimulus_id,
+                }
+            ]
         elif ts.run_id != run_id:
             if not ts.processing_on or ts.processing_on.address != worker:
                 logger.debug(
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 652997a2f3..5dc6153860 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -4890,3 +4890,139 @@ async def test_resubmit_different_task_same_key_warns_only_once(
 
     async with Worker(s.address):
         assert await c.gather(zs) == [2, 3, 4]  # Kept old ys
+
+
+def block(x, in_event, block_event):
+    in_event.set()
+    block_event.wait()
+    return x
+
+
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1, {"resources": {"a": 1}})],
+    config={"distributed.scheduler.allowed-failures": 0},
+)
+async def test_fan_out_pattern_deadlock(c, s, a):
+    """Regression test for https://github.com/dask/distributed/issues/8548
+
+    This test heavily uses resources to force scheduling decisions.
+    """
+    in_f, block_f = Event(), Event()
+    in_ha, block_ha = Event(), Event()
+    in_hb, block_hb = Event(), Event()
+
+    # Input task to 'g' that we can fail
+    with dask.annotate(resources={"b": 1}):
+        f = delayed(block)(1, in_f, block_f, dask_key_name="f")
+    g = delayed(inc)(f, dask_key_name="g")
+
+    # Fan out from 'g' and run 'ha' and 'hb' on different workers
+    hb = delayed(block)(g, in_hb, block_hb, dask_key_name="hb")
+    with dask.annotate(resources={"a": 1}):
+        ha = delayed(block)(g, in_ha, block_ha, dask_key_name="ha")
+
+    f, ha, hb = c.compute([f, ha, hb])
+    with captured_logger("distributed.scheduler", level=logging.ERROR) as logger:
+        async with Worker(s.address, nthreads=1, resources={"b": 1}) as b:
+            await block_f.set()
+            await in_ha.wait()
+            await in_hb.wait()
+            await in_f.clear()
+
+            # Make sure that the scheduler knows that both workers hold 'g' in memory
+            await async_poll_for(lambda: len(s.tasks["g"].who_has) == 2, timeout=5)
+            # Remove worker 'b' while it's processing 'hb'
+            await s.remove_worker(b.address, stimulus_id="remove_b1")
+            await block_hb.set()
+        await block_f.clear()
+
+        # Remove the new instance of the 'b' worker while it processes 'f'
+        # to trigger a transition for 'f' to 'erred'
+        async with Worker(s.address, nthreads=1, resources={"b": 1}) as b:
+            await in_f.wait()
+            await in_f.clear()
+            await s.remove_worker(b.address, stimulus_id="remove_b2")
+            await block_f.set()
+        await block_f.clear()
+
+        await block_ha.set()
+        await ha
+
+        with pytest.raises(KilledWorker, match="Attempted to run task 'hb'"):
+            await hb
+
+        del ha, hb
+        # Make sure that 'ha' gets forgotten on worker 'a'
+        await async_poll_for(lambda: not a.state.tasks, timeout=5)
+    # Ensure that no other errors including transition failures were logged
+    assert (
+        logger.getvalue()
+        == "Task hb marked as failed because 1 workers died while trying to run it\nTask f marked as failed because 1 workers died while trying to run it\n"
+    )
+
+
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1, {"resources": {"a": 1}})],
+    config={"distributed.scheduler.allowed-failures": 0},
+)
+async def test_stimulus_from_erred_task(c, s, a):
+    """This test heavily uses resources to force scheduling decisions."""
+    in_f, block_f = Event(), Event()
+    in_g, block_g = Event(), Event()
+
+    with dask.annotate(resources={"b": 1}):
+        f = delayed(block)(1, in_f, block_f, dask_key_name="f")
+
+    with dask.annotate(resources={"a": 1}):
+        g = delayed(block)(f, in_g, block_g, dask_key_name="g")
+
+    f, g = c.compute([f, g])
+    with captured_logger("distributed.scheduler", level=logging.ERROR) as logger:
+        frozen_stream_from_a_ctx = freeze_batched_send(a.batched_stream)
+        frozen_stream_from_a_ctx.__enter__()
+
+        async with Worker(s.address, nthreads=1, resources={"b": 1}) as b1:
+            await block_f.set()
+            await in_g.wait()
+            await in_f.clear()
+            frozen_stream_to_a_ctx = freeze_batched_send(s.stream_comms[a.address])
+            frozen_stream_to_a_ctx.__enter__()
+            await s.remove_worker(b1.address, stimulus_id="remove_b1")
+        await block_f.clear()
+
+        # Remove the new instance of the 'b' worker while it processes 'f'
+        # to trigger a transition for 'f' to 'erred'
+        async with Worker(s.address, nthreads=1, resources={"b": 1}) as b2:
+            await in_f.wait()
+            await in_f.clear()
+            await s.remove_worker(b2.address, stimulus_id="remove_b2")
+            await block_f.set()
+
+        with pytest.raises(KilledWorker, match="Attempted to run task 'f'"):
+            await f
+
+        # g has already been transitioned to 'erred' because 'f' failed
+        with pytest.raises(KilledWorker, match="Attempted to run task 'f'"):
+            await g
+
+        # Finish 'g' and let the scheduler know so it can trigger cleanup
+        await block_g.set()
+        with mock.patch.object(
+            s, "stimulus_task_finished", wraps=s.stimulus_task_finished
+        ) as wrapped_stimulus:
+            frozen_stream_from_a_ctx.__exit__(None, None, None)
+            # Make sure the `stimulus_task_finished` gets processed
+            await async_poll_for(lambda: wrapped_stimulus.call_count == 1, timeout=5)
+
+        # Allow the scheduler to talk to the worker again
+        frozen_stream_to_a_ctx.__exit__(None, None, None)
+        # Make sure all data gets forgotten on worker 'a'
+        await async_poll_for(lambda: not a.state.tasks, timeout=5)
+
+    # Ensure that no other errors including transition failures were logged
+    assert (
+        logger.getvalue()
+        == "Task f marked as failed because 1 workers died while trying to run it\n"
+    )