Skip to content

Commit

Permalink
SpecCluster resilience to broken workers (#8233)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Oct 6, 2023
1 parent de3f755 commit 9a8b380
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 44 deletions.
26 changes: 15 additions & 11 deletions distributed/deploy/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,13 +379,15 @@ async def _correct_state_internal(self) -> None:
self._created.add(worker)
workers.append(worker)
if workers:
await asyncio.wait(
[asyncio.create_task(_wrap_awaitable(w)) for w in workers]
)
worker_futs = [asyncio.ensure_future(w) for w in workers]
await asyncio.wait(worker_futs)
self.workers.update(dict(zip(to_open, workers)))
for w in workers:
w._cluster = weakref.ref(self)
await w # for tornado gen.coroutine support
self.workers.update(dict(zip(to_open, workers)))
# Collect exceptions from failed workers. This must happen after all
# *other* workers have finished initialising, so that we can have a
# proper teardown.
await asyncio.gather(*worker_futs)

def _update_worker_status(self, op, msg):
if op == "remove":
Expand Down Expand Up @@ -467,10 +469,14 @@ async def _close(self):
await super()._close()

async def __aenter__(self):
await self
await self._correct_state()
assert self.status == Status.running
return self
try:
await self
await self._correct_state()
assert self.status == Status.running
return self
except Exception:
await self.close()
raise

def _threads_per_worker(self) -> int:
"""Return the number of threads per worker for new workers"""
Expand Down Expand Up @@ -678,8 +684,6 @@ async def run_spec(spec: dict[str, Any], *args: Any) -> dict[str, Worker | Nanny

if workers:
await asyncio.gather(*workers.values())
for w in workers.values():
await w # for tornado gen.coroutine support
return workers


Expand Down
31 changes: 20 additions & 11 deletions distributed/deploy/tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,16 +1141,18 @@ async def test_local_cluster_redundant_kwarg(nanny):
dashboard_address=":0",
asynchronous=True,
)
try:
with pytest.raises(TypeError, match="unexpected keyword argument"):
# Extra arguments are forwarded to the worker class. Depending on
# whether we use the nanny or not, the error treatment is quite
# different and we should assert that an exception is raised
async with cluster:
pass
finally:
# FIXME: LocalCluster leaks if LocalCluster.__aenter__ raises
await cluster.close()
if nanny:
ctx = raises_with_cause(
RuntimeError, None, TypeError, "unexpected keyword argument"
)
else:
ctx = pytest.raises(TypeError, match="unexpected keyword argument")
with ctx:
# Extra arguments are forwarded to the worker class. Depending on
# whether we use the nanny or not, the error treatment is quite
# different and we should assert that an exception is raised
async with cluster:
pass


@gen_test()
Expand Down Expand Up @@ -1255,7 +1257,14 @@ def setup(self, worker=None):

@pytest.mark.slow
def test_localcluster_start_exception(loop):
with raises_with_cause(RuntimeError, None, ImportError, "my_nonexistent_library"):
with raises_with_cause(
RuntimeError,
"Nanny failed to start",
RuntimeError,
"Worker failed to start",
ImportError,
"my_nonexistent_library",
):
with LocalCluster(
n_workers=1,
threads_per_worker=1,
Expand Down
12 changes: 3 additions & 9 deletions distributed/deploy/tests/test_spec_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ async def test_restart():
await asyncio.sleep(0.01)


@pytest.mark.skipif(WINDOWS, reason="HTTP Server doesn't close out")
@gen_test()
async def test_broken_worker():
class BrokenWorkerException(Exception):
Expand All @@ -216,7 +215,6 @@ class BrokenWorkerException(Exception):
class BrokenWorker(Worker):
def __await__(self):
async def _():
self.status = Status.closed
raise BrokenWorkerException("Worker Broken")

return _().__await__()
Expand All @@ -226,13 +224,9 @@ async def _():
workers={"good": {"cls": Worker}, "bad": {"cls": BrokenWorker}},
scheduler=scheduler,
)
try:
with pytest.raises(BrokenWorkerException, match=r"Worker Broken"):
async with cluster:
pass
finally:
# FIXME: SpecCluster leaks if SpecCluster.__aenter__ raises
await cluster.close()
with pytest.raises(BrokenWorkerException, match=r"Worker Broken"):
async with cluster:
pass


@pytest.mark.skipif(WINDOWS, reason="HTTP Server doesn't close out")
Expand Down
35 changes: 31 additions & 4 deletions distributed/tests/test_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,16 +764,16 @@ def test_raises_with_cause():
raise RuntimeError("foo") from ValueError("bar")

# we're trying to stick to pytest semantics
# If the exception types don't match, raise the original exception
# If the exception types don't match, raise the first exception that doesn't match
# If the text doesn't match, raise an assert

with pytest.raises(RuntimeError):
with pytest.raises(OSError):
with raises_with_cause(RuntimeError, "exception", ValueError, "cause"):
raise RuntimeError("exception") from OSError("cause")

with pytest.raises(ValueError):
with pytest.raises(OSError):
with raises_with_cause(RuntimeError, "exception", ValueError, "cause"):
raise ValueError("exception") from ValueError("cause")
raise OSError("exception") from ValueError("cause")

with pytest.raises(AssertionError):
with raises_with_cause(RuntimeError, "exception", ValueError, "foo"):
Expand All @@ -783,6 +783,33 @@ def test_raises_with_cause():
with raises_with_cause(RuntimeError, "foo", ValueError, "cause"):
raise RuntimeError("exception") from ValueError("cause")

# There can be more than one nested cause
with raises_with_cause(
RuntimeError, "exception", ValueError, "cause1", OSError, "cause2"
):
try:
raise ValueError("cause1") from OSError("cause2")
except ValueError as e:
raise RuntimeError("exception") from e

with pytest.raises(OSError):
with raises_with_cause(
RuntimeError, "exception", ValueError, "cause1", TypeError, "cause2"
):
try:
raise ValueError("cause1") from OSError("cause2")
except ValueError as e:
raise RuntimeError("exception") from e

with pytest.raises(AssertionError):
with raises_with_cause(
RuntimeError, "exception", ValueError, "cause1", OSError, "cause2"
):
try:
raise ValueError("cause1") from OSError("no match")
except ValueError as e:
raise RuntimeError("exception") from e


@pytest.mark.slow
def test_check_thread_leak():
Expand Down
20 changes: 11 additions & 9 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import logging.config
import multiprocessing
import os
import re
import signal
import socket
import ssl
Expand Down Expand Up @@ -2100,8 +2099,10 @@ def raises_with_cause(
match: str | None,
expected_cause: type[BaseException] | tuple[type[BaseException], ...],
match_cause: str | None,
*more_causes: type[BaseException] | tuple[type[BaseException], ...] | str | None,
) -> Generator[None, None, None]:
"""Contextmanager to assert that a certain exception with cause was raised
"""Contextmanager to assert that a certain exception with cause was raised.
It can walk the chain of causes recursively: append further (expected_cause, match_cause) pairs at the end.
Parameters
----------
Expand All @@ -2111,13 +2112,14 @@ def raises_with_cause(
yield

exc = exc_info.value
assert exc.__cause__
if not isinstance(exc.__cause__, expected_cause):
raise exc
if match_cause:
assert re.search(
match_cause, str(exc.__cause__)
), f"Pattern ``{match_cause}`` not found in ``{exc.__cause__}``"
causes = [expected_cause, *more_causes[::2]]
match_causes = [match_cause, *more_causes[1::2]]
assert len(causes) == len(match_causes)
for expected_cause, match_cause in zip(causes, match_causes): # type: ignore
assert exc.__cause__
exc = exc.__cause__
with pytest.raises(expected_cause, match=match_cause):
raise exc


def ucx_exception_handler(loop, context):
Expand Down

0 comments on commit 9a8b380

Please sign in to comment.