Skip to content

Commit

Permalink
WSMR/deserialize_task (#6411)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored May 23, 2022
1 parent 97a7eb6 commit d84485b
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 48 deletions.
18 changes: 18 additions & 0 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2569,6 +2569,24 @@ def __call__(self, *args, **kwargs):
threadpool.shutdown()


@gen_cluster(client=True)
async def test_run_spec_deserialize_fail(c, s, a, b):
class F:
def __call__(self):
pass

def __reduce__(self):
return lambda: 1 / 0, ()

with captured_logger("distributed.worker") as logger:
fut = c.submit(F())
assert isinstance(await fut.exception(), ZeroDivisionError)

logvalue = logger.getvalue()
assert "Could not deserialize task" in logvalue
assert "return lambda: 1 / 0, ()" in logvalue


@gen_cluster(client=True)
async def test_gather_dep_exception_one_task(c, s, a, b):
"""Ensure an exception in a single task does not tear down an entire batch of gather_dep
Expand Down
75 changes: 29 additions & 46 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3645,36 +3645,22 @@ def actor_attribute(self, actor=None, attribute=None) -> dict[str, Any]:
return {"status": "error", "exception": to_serialize(ex)}

async def _maybe_deserialize_task(
self, ts: TaskState, *, stimulus_id: str
) -> tuple[Callable, tuple, dict[str, Any]] | None:
if ts.run_spec is None:
return None
try:
start = time()
# Offload deserializing large tasks
if sizeof(ts.run_spec) > OFFLOAD_THRESHOLD:
function, args, kwargs = await offload(_deserialize, *ts.run_spec)
else:
function, args, kwargs = _deserialize(*ts.run_spec)
stop = time()
self, ts: TaskState
) -> tuple[Callable, tuple, dict[str, Any]]:
assert ts.run_spec is not None
start = time()
# Offload deserializing large tasks
if sizeof(ts.run_spec) > OFFLOAD_THRESHOLD:
function, args, kwargs = await offload(_deserialize, *ts.run_spec)
else:
function, args, kwargs = _deserialize(*ts.run_spec)
stop = time()

if stop - start > 0.010:
ts.startstops.append(
{"action": "deserialize", "start": start, "stop": stop}
)
return function, args, kwargs
except Exception as e:
logger.error("Could not deserialize task", exc_info=True)
self.log.append((ts.key, "deserialize-error", stimulus_id, time()))
emsg = error_message(e)
del emsg["status"] # type: ignore
self.transition(
ts,
"error",
**emsg,
stimulus_id=stimulus_id,
if stop - start > 0.010:
ts.startstops.append(
{"action": "deserialize", "start": start, "stop": stop}
)
raise
return function, args, kwargs

def _ensure_computing(self) -> RecsInstrs:
if self.status != Status.running:
Expand Down Expand Up @@ -3747,16 +3733,22 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
)
return AlreadyCancelledEvent(key=ts.key, stimulus_id=stimulus_id)

try:
function, args, kwargs = await self._maybe_deserialize_task(ts)
except Exception as exc:
logger.error("Could not deserialize task %s", key, exc_info=True)
return ExecuteFailureEvent.from_exception(
exc,
key=key,
stimulus_id=f"run-spec-deserialize-failed-{time()}",
)

try:
if self.validate:
assert not ts.waiting_for_data
assert ts.state == "executing"
assert ts.run_spec is not None

function, args, kwargs = await self._maybe_deserialize_task( # type: ignore
ts, stimulus_id=stimulus_id
)

args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs)

try:
Expand Down Expand Up @@ -3837,29 +3829,20 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
convert_kwargs_to_str(kwargs2, max_len=1000),
result["exception_text"],
)
return ExecuteFailureEvent(
return ExecuteFailureEvent.from_exception(
result,
key=key,
start=result["start"],
stop=result["stop"],
exception=result["exception"],
traceback=result["traceback"],
exception_text=result["exception_text"],
traceback_text=result["traceback_text"],
stimulus_id=f"task-erred-{time()}",
)

except Exception as exc:
logger.error("Exception during execution of task %s.", key, exc_info=True)
msg = error_message(exc)
return ExecuteFailureEvent(
return ExecuteFailureEvent.from_exception(
exc,
key=key,
start=None,
stop=None,
exception=msg["exception"],
traceback=msg["traceback"],
exception_text=msg["exception_text"],
traceback_text=msg["traceback_text"],
stimulus_id=f"task-erred-{time()}",
stimulus_id=f"execute-unknown-error-{time()}",
)

@functools.singledispatchmethod
Expand Down
29 changes: 27 additions & 2 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import dask
from dask.utils import parse_bytes

from distributed.core import ErrorMessage, error_message
from distributed.protocol.serialize import Serialize
from distributed.utils import recursive_to_dict

Expand Down Expand Up @@ -458,7 +459,6 @@ class ExecuteSuccessEvent(StateMachineEvent):
stop: float
nbytes: int
type: type | None
stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore

def to_loggable(self, *, handled: float) -> StateMachineEvent:
Expand All @@ -481,13 +481,38 @@ class ExecuteFailureEvent(StateMachineEvent):
traceback: Serialize | None
exception_text: str
traceback_text: str
stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore

def _after_from_dict(self) -> None:
self.exception = Serialize(Exception())
self.traceback = None

@classmethod
def from_exception(
cls,
err_or_msg: BaseException | ErrorMessage,
*,
key: str,
start: float | None = None,
stop: float | None = None,
stimulus_id: str,
) -> ExecuteFailureEvent:
if isinstance(err_or_msg, dict):
msg = err_or_msg
else:
msg = error_message(err_or_msg)

return cls(
key=key,
start=start,
stop=stop,
exception=msg["exception"],
traceback=msg["traceback"],
exception_text=msg["exception_text"],
traceback_text=msg["traceback_text"],
stimulus_id=stimulus_id,
)


@dataclass
class CancelComputeEvent(StateMachineEvent):
Expand Down

0 comments on commit d84485b

Please sign in to comment.