diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 034e0bf2953..9e9330e45e9 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -2569,6 +2569,24 @@ def __call__(self, *args, **kwargs):
         threadpool.shutdown()
 
 
+@gen_cluster(client=True)
+async def test_run_spec_deserialize_fail(c, s, a, b):
+    class F:
+        def __call__(self):
+            pass
+
+        def __reduce__(self):
+            return lambda: 1 / 0, ()
+
+    with captured_logger("distributed.worker") as logger:
+        fut = c.submit(F())
+        assert isinstance(await fut.exception(), ZeroDivisionError)
+
+    logvalue = logger.getvalue()
+    assert "Could not deserialize task" in logvalue
+    assert "return lambda: 1 / 0, ()" in logvalue
+
+
 @gen_cluster(client=True)
 async def test_gather_dep_exception_one_task(c, s, a, b):
     """Ensure an exception in a single task does not tear down an entire batch of gather_dep
diff --git a/distributed/worker.py b/distributed/worker.py
index 437a5b20773..299e1dae9e0 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -3645,36 +3645,22 @@ def actor_attribute(self, actor=None, attribute=None) -> dict[str, Any]:
             return {"status": "error", "exception": to_serialize(ex)}
 
     async def _maybe_deserialize_task(
-        self, ts: TaskState, *, stimulus_id: str
-    ) -> tuple[Callable, tuple, dict[str, Any]] | None:
-        if ts.run_spec is None:
-            return None
-        try:
-            start = time()
-            # Offload deserializing large tasks
-            if sizeof(ts.run_spec) > OFFLOAD_THRESHOLD:
-                function, args, kwargs = await offload(_deserialize, *ts.run_spec)
-            else:
-                function, args, kwargs = _deserialize(*ts.run_spec)
-            stop = time()
+        self, ts: TaskState
+    ) -> tuple[Callable, tuple, dict[str, Any]]:
+        assert ts.run_spec is not None
+        start = time()
+        # Offload deserializing large tasks
+        if sizeof(ts.run_spec) > OFFLOAD_THRESHOLD:
+            function, args, kwargs = await offload(_deserialize, *ts.run_spec)
+        else:
+            function, args, kwargs = _deserialize(*ts.run_spec)
+        stop = time()
 
-            if stop - start > 0.010:
-                ts.startstops.append(
-                    {"action": "deserialize", "start": start, "stop": stop}
-                )
-            return function, args, kwargs
-        except Exception as e:
-            logger.error("Could not deserialize task", exc_info=True)
-            self.log.append((ts.key, "deserialize-error", stimulus_id, time()))
-            emsg = error_message(e)
-            del emsg["status"]  # type: ignore
-            self.transition(
-                ts,
-                "error",
-                **emsg,
-                stimulus_id=stimulus_id,
+        if stop - start > 0.010:
+            ts.startstops.append(
+                {"action": "deserialize", "start": start, "stop": stop}
             )
-            raise
+        return function, args, kwargs
 
     def _ensure_computing(self) -> RecsInstrs:
         if self.status != Status.running:
@@ -3747,16 +3733,22 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
             )
             return AlreadyCancelledEvent(key=ts.key, stimulus_id=stimulus_id)
 
+        try:
+            function, args, kwargs = await self._maybe_deserialize_task(ts)
+        except Exception as exc:
+            logger.error("Could not deserialize task %s", key, exc_info=True)
+            return ExecuteFailureEvent.from_exception(
+                exc,
+                key=key,
+                stimulus_id=f"run-spec-deserialize-failed-{time()}",
+            )
+
         try:
             if self.validate:
                 assert not ts.waiting_for_data
                 assert ts.state == "executing"
                 assert ts.run_spec is not None
 
-            function, args, kwargs = await self._maybe_deserialize_task(  # type: ignore
-                ts, stimulus_id=stimulus_id
-            )
-
             args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs)
 
             try:
@@ -3837,29 +3829,20 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
                     convert_kwargs_to_str(kwargs2, max_len=1000),
                     result["exception_text"],
                 )
-            return ExecuteFailureEvent(
+            return ExecuteFailureEvent.from_exception(
+                result,
                 key=key,
                 start=result["start"],
                 stop=result["stop"],
-                exception=result["exception"],
-                traceback=result["traceback"],
-                exception_text=result["exception_text"],
-                traceback_text=result["traceback_text"],
                 stimulus_id=f"task-erred-{time()}",
             )
 
         except Exception as exc:
             logger.error("Exception during execution of task %s.", key, exc_info=True)
-            msg = error_message(exc)
-            return ExecuteFailureEvent(
+            return ExecuteFailureEvent.from_exception(
+                exc,
                 key=key,
-                start=None,
-                stop=None,
-                exception=msg["exception"],
-                traceback=msg["traceback"],
-                exception_text=msg["exception_text"],
-                traceback_text=msg["traceback_text"],
-                stimulus_id=f"task-erred-{time()}",
+                stimulus_id=f"execute-unknown-error-{time()}",
             )
 
     @functools.singledispatchmethod
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index f5fa39c0802..060028aa24b 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -12,6 +12,7 @@
 import dask
 from dask.utils import parse_bytes
 
+from distributed.core import ErrorMessage, error_message
 from distributed.protocol.serialize import Serialize
 from distributed.utils import recursive_to_dict
 
@@ -458,7 +459,6 @@ class ExecuteSuccessEvent(StateMachineEvent):
     stop: float
     nbytes: int
     type: type | None
-    stimulus_id: str
     __slots__ = tuple(__annotations__)  # type: ignore
 
     def to_loggable(self, *, handled: float) -> StateMachineEvent:
@@ -481,13 +481,38 @@ class ExecuteFailureEvent(StateMachineEvent):
     traceback: Serialize | None
     exception_text: str
     traceback_text: str
-    stimulus_id: str
     __slots__ = tuple(__annotations__)  # type: ignore
 
     def _after_from_dict(self) -> None:
         self.exception = Serialize(Exception())
         self.traceback = None
 
+    @classmethod
+    def from_exception(
+        cls,
+        err_or_msg: BaseException | ErrorMessage,
+        *,
+        key: str,
+        start: float | None = None,
+        stop: float | None = None,
+        stimulus_id: str,
+    ) -> ExecuteFailureEvent:
+        if isinstance(err_or_msg, dict):
+            msg = err_or_msg
+        else:
+            msg = error_message(err_or_msg)
+
+        return cls(
+            key=key,
+            start=start,
+            stop=stop,
+            exception=msg["exception"],
+            traceback=msg["traceback"],
+            exception_text=msg["exception_text"],
+            traceback_text=msg["traceback_text"],
+            stimulus_id=stimulus_id,
+        )
+
 
 @dataclass
 class CancelComputeEvent(StateMachineEvent):
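
Note on the refactor: both failure paths in execute() previously assembled ExecuteFailureEvent field by field, one from the message dict that apply_function had already built via error_message(), the other from a freshly caught exception. The new from_exception classmethod accepts either shape. Below is a minimal self-contained sketch of that pattern; the error_message stand-in is simplified and hypothetical (the real one in distributed.core also wraps the exception for wire serialization and truncates long tracebacks):

    from __future__ import annotations

    import traceback
    from dataclasses import dataclass
    from typing import Any


    def error_message(exc: BaseException) -> dict[str, Any]:
        # Simplified stand-in for distributed.core.error_message.
        return {
            "exception": exc,
            "traceback": exc.__traceback__,
            "exception_text": repr(exc),
            "traceback_text": "".join(
                traceback.format_exception(type(exc), exc, exc.__traceback__)
            ),
        }


    @dataclass
    class ExecuteFailureEvent:
        key: str
        start: float | None
        stop: float | None
        exception: BaseException
        traceback: Any
        exception_text: str
        traceback_text: str
        stimulus_id: str

        @classmethod
        def from_exception(
            cls,
            err_or_msg: BaseException | dict[str, Any],
            *,
            key: str,
            start: float | None = None,
            stop: float | None = None,
            stimulus_id: str,
        ) -> ExecuteFailureEvent:
            # Accept either a bare exception or a message dict a caller
            # already built, so every failure path shares one construction
            # site instead of repeating the field-by-field plumbing.
            msg = err_or_msg if isinstance(err_or_msg, dict) else error_message(err_or_msg)
            return cls(
                key=key,
                start=start,
                stop=stop,
                exception=msg["exception"],
                traceback=msg["traceback"],
                exception_text=msg["exception_text"],
                traceback_text=msg["traceback_text"],
                stimulus_id=stimulus_id,
            )


    try:
        1 / 0
    except ZeroDivisionError as exc:
        event = ExecuteFailureEvent.from_exception(
            exc, key="x", stimulus_id="execute-unknown-error-0"
        )
        assert event.start is None
        assert "ZeroDivisionError" in event.exception_text

The start/stop defaults of None let the catch-all path in execute() drop the explicit start=None, stop=None arguments that the old call site spelled out.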
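
And a note on the mechanism behind test_run_spec_deserialize_fail: __reduce__ returns a (callable, args) pair, and the unpickler calls callable(*args) to rebuild the object, so the ZeroDivisionError fires while the worker deserializes the run spec, before the task ever executes. The lambda in the test survives serialization presumably because distributed falls back to cloudpickle; the standalone sketch below uses the stdlib pickler, which cannot serialize lambdas, so it swaps in a hypothetical module-level function _boom:

    import pickle


    def _boom():
        # Invoked at unpickling time; raises before any object is rebuilt.
        return 1 / 0


    class F:
        def __call__(self):
            pass

        def __reduce__(self):
            # pickle records (_boom, ()); pickle.loads() then calls _boom().
            return _boom, ()


    payload = pickle.dumps(F())  # succeeds: only a reference to _boom is stored
    try:
        pickle.loads(payload)
    except ZeroDivisionError:
        print("error surfaced during deserialization, not execution")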