Merge pull request #494 from procrastinate-org/exception-swallowed-493
Stop the whole worker process when a coroutine raises
Joachim Jablon authored Dec 5, 2021
2 parents 7230ac2 + c2588e3 commit e094777
Showing 4 changed files with 426 additions and 83 deletions.
4 changes: 4 additions & 0 deletions procrastinate/exceptions.py
@@ -103,3 +103,7 @@ class CallerModuleUnknown(ProcrastinateException):
"""
Unable to determine the module name of the caller.
"""


class RunTaskError(ProcrastinateException):
"""One of the specified coroutines ended with an exception"""
218 changes: 184 additions & 34 deletions procrastinate/utils.py
@@ -1,14 +1,25 @@
import asyncio
import contextlib
import datetime
import functools
import importlib
import inspect
import logging
import pathlib
import sys
import types
from typing import Any, Awaitable, Callable, Iterable, Optional, Type, TypeVar

from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Iterable,
List,
Optional,
Type,
TypeVar,
)

import attr
import dateutil.parser

from procrastinate import exceptions
@@ -224,37 +235,6 @@ def get_full_path(obj: Any) -> str:
return f"{_get_module_name(obj)}.{obj.__name__}"


@contextlib.contextmanager
def task_context(awaitable: Awaitable, name: str):
"""
Take an awaitable, return a context manager.
On enter, launch the awaitable as a task that will execute in parallel in the
event loop. On exit, cancel the task (and log). If the task ends with an exception,
log it.
A name is required for logging purposes.
"""
nice_name = name.replace("_", " ").title()

async def wrapper():
try:
logger.debug(f"Started {nice_name}", extra={"action": f"{name}_start"})
await awaitable
except asyncio.CancelledError:
logger.debug(f"Stopped {nice_name}", extra={"action": f"{name}_stop"})
raise

except Exception:
logger.exception(f"{nice_name} error", extra={"action": f"{name}_error"})

try:
task = asyncio.ensure_future(wrapper())
yield task
finally:
task.cancel()


def utcnow() -> datetime.datetime:
return datetime.datetime.now(tz=datetime.timezone.utc)

@@ -307,3 +287,173 @@ async def _inner_coro() -> U:
return self._return_value

return _inner_coro().__await__()


class EndMain(Exception):
pass


@attr.dataclass()
class ExceptionRecord:
task: asyncio.Task
exc: Exception


async def run_tasks(
main_coros: Iterable[Coroutine],
side_coros: Optional[Iterable[Coroutine]] = None,
graceful_stop_callback: Optional[Callable[[], Any]] = None,
):
"""
Run multiple coroutines in parallel: the main coroutines and the side
coroutines. Side coroutines are expected to run until they get cancelled.
Main corountines are expected to return at some point. By default, this
function will return None, but on certain circumstances, (see below) it can
raise a `RunTaskError`. A callback `graceful_stop_callback` will be called
if provided to ask the main coroutines to gracefully stop in case either
one of them or one of the side coroutines raise.
- If all coroutines from main_coros return and there is no exception in the
coroutines from either `main_coros` or `side_coros`:
- coroutines from `side_coros` are cancelled and awaited
- the function return None
- If any corountine from `main_coros` or `side_coros` raises an exception:
- `graceful_stop_callback` is called (the idea is that it should ask
coroutines from `main_coros` to exit gracefully)
- the function then wait for main_coros to finish, registering any
additional exception
- coroutines from `side_coros` are cancelled and awaited, registering any
additional exception
- all exceptions from coroutines in both `main_coros` and `side_coros`
are logged
- the function raises `RunTaskError`
It's not expected that coroutines from `side_coros` return. If this
happens, the function will not react in a specific way.
When a `RunTaskError` is raised because of one or more underlying
exceptions, one exception is the `__cause__` (the first main or side
coroutine that fails in the input iterables order, which will probably not
the chronologically the first one to be raised). All exceptions are logged.
"""
# Ensure all passed coros are futures (in our case, Tasks). This means that
# all the coroutines start executing now.
# The `name` argument to create_task only exists on Python 3.8+
if sys.version_info < (3, 8):
main_tasks = [asyncio.create_task(coro) for coro in main_coros]
side_tasks = [asyncio.create_task(coro) for coro in side_coros or []]
else:
main_tasks = [
asyncio.create_task(coro, name=coro.__name__) for coro in main_coros
]
side_tasks = [
asyncio.create_task(coro, name=coro.__name__) for coro in side_coros or []
]
for task in main_tasks + side_tasks:
name = task.get_name()
logger.debug(
f"Started {name}",
extra={
"action": f"{name}_start",
},
)

# Note that asyncio.gather() has 2 modes of execution:
# - asyncio.gather(*aws)
# Interrupts the gather at the first exception, and raises this
# exception. Otherwise, returns a list containing the return values of all
# coroutines
# - asyncio.gather(*aws, return_exceptions=True)
# Runs every coroutine until the end, returns a list of either return
# values or raised exceptions (mixed).

# The _main function will always raise: either an exception if one happens
# in the main tasks, or EndMain if every coroutine returned
async def _main():
await asyncio.gather(*main_tasks)
raise EndMain

exception_records: List[ExceptionRecord] = []
try:
# side_tasks supposedly never finish, and _main always raises.
# Consequently, it's theoretically impossible to leave this try block
# without going through one of the except branches.
await asyncio.gather(_main(), *side_tasks)
except EndMain:
pass
except Exception as exc:
logger.error(
"Main coroutine error, initiating remaining coroutines stop. "
f"Cause: {exc!r}",
extra={
"action": "run_tasks_stop_requested",
},
)
if graceful_stop_callback:
graceful_stop_callback()

# Even if we asked the main tasks to stop, we still need to wait for
# them to actually stop. This may take some time. At this point, any
# additional exception will be registered but will not impact execution
# flow.
results = await asyncio.gather(*main_tasks, return_exceptions=True)
for task, result in zip(main_tasks, results):
if isinstance(result, Exception):
exception_records.append(
ExceptionRecord(
task=task,
exc=result,
)
)
else:
if sys.version_info >= (3, 8):
name = task.get_name()
logger.debug(
f"{name} finished execution",
extra={
"action": f"{name}_stop",
},
)

for task in side_tasks:
task.cancel()
try:
# task.cancel() says that the next time a task is executed, it will
# raise, but we need to give control back to the task for it to
# actually receive the exception.
await task
except asyncio.CancelledError:
if sys.version_info >= (3, 8):
name = task.get_name()
logger.debug(
f"Stopped {name}",
extra={
"action": f"{name}_stop",
},
)
except Exception as exc:
exception_records.append(
ExceptionRecord(
task=task,
exc=exc,
)
)

for exception_record in exception_records:
if sys.version_info < (3, 8):
message = f"{exception_record.exc!r}"
action = "run_tasks_error"
else:
name = exception_record.task.get_name()
message = f"{name} error: {exception_record.exc!r}"
action = f"{name}_error"
logger.exception(
message,
extra={
"action": action,
},
)

if exception_records:
raise exceptions.RunTaskError from exception_records[0].exc
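
To make the new API concrete, here is a minimal usage sketch of `utils.run_tasks`, assuming only what the diff above shows (the `main_coros`, `side_coros` and `graceful_stop_callback` parameters, and `exceptions.RunTaskError`); the `do_work` and `heartbeat` coroutines and the event-based stop are illustrative, not part of procrastinate:

import asyncio

from procrastinate import exceptions, utils


async def do_work(stop: asyncio.Event, fail: bool = False):
    # A "main" coroutine: works in small steps, exits early when asked to stop.
    for _ in range(20):
        if stop.is_set():
            return
        await asyncio.sleep(0.05)
    if fail:
        raise RuntimeError("something went wrong")


async def heartbeat():
    # A "side" coroutine: runs until run_tasks cancels it.
    while True:
        await asyncio.sleep(1)


async def main():
    stop = asyncio.Event()
    try:
        await utils.run_tasks(
            main_coros=[do_work(stop), do_work(stop, fail=True)],
            side_coros=[heartbeat()],
            # Called as soon as any coroutine raises, so that the other main
            # coroutines can wind down gracefully.
            graceful_stop_callback=stop.set,
        )
    except exceptions.RunTaskError as exc:
        # __cause__ is the exception of the first failing coroutine,
        # in the order of the input iterables.
        print(f"stopping the process, cause: {exc.__cause__!r}")


asyncio.run(main())
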
33 changes: 14 additions & 19 deletions procrastinate/worker.py
@@ -1,5 +1,4 @@
import asyncio
import contextlib
import logging
import time
from enum import Enum
@@ -88,19 +87,15 @@ def context_for_worker(

return context

def listener(self):
async def listener(self):
assert self.notify_event
return utils.task_context(
awaitable=self.job_manager.listen_for_jobs(
event=self.notify_event, queues=self.queues
),
name="listener",
return await self.job_manager.listen_for_jobs(
event=self.notify_event,
queues=self.queues,
)

def periodic_deferrer(self):
return utils.task_context(
awaitable=self.app.periodic_deferrer.worker(), name="periodic_deferrer"
)
async def periodic_deferrer(self):
return await self.app.periodic_deferrer.worker()

async def run(self) -> None:
self.notify_event = asyncio.Event()
@@ -113,18 +108,18 @@ async def run(self) -> None:
),
)

with contextlib.ExitStack() as stack:
with signals.on_stop(self.stop):
side_coros = [self.periodic_deferrer()]
if self.wait and self.listen_notify:
stack.enter_context(self.listener())
side_coros.append(self.listener())

stack.enter_context(self.periodic_deferrer())
stack.enter_context(signals.on_stop(self.stop))

await asyncio.gather(
*(
await utils.run_tasks(
main_coros=(
self.single_worker(worker_id=worker_id)
for worker_id in range(self.concurrency)
)
),
side_coros=side_coros,
graceful_stop_callback=self.stop,
)

self.logger.info(
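
For contrast with the new behaviour, the removed `task_context` helper (see the `utils.py` diff above) swallowed exceptions from side coroutines: they were logged inside the wrapper and never reached the worker, which is the bug this commit fixes. A condensed, self-contained reproduction of that pattern (not the exact removed code, which also logged start and stop messages):

import asyncio
import contextlib
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")


@contextlib.contextmanager
def task_context(awaitable, name):
    # Condensed version of the removed helper: run the awaitable as a task,
    # cancel it on exit, and log (but never re-raise) its exceptions.
    async def wrapper():
        try:
            await awaitable
        except asyncio.CancelledError:
            raise
        except Exception:
            # The exception stops here: it is logged and swallowed.
            logger.exception(f"{name} error")

    try:
        task = asyncio.ensure_future(wrapper())
        yield task
    finally:
        task.cancel()


async def failing_listener():
    raise RuntimeError("listener connection lost")


async def main():
    with task_context(failing_listener(), name="listener"):
        await asyncio.sleep(0.1)
        # The surrounding code never sees the RuntimeError and keeps going.
        print("worker still running; the error above was only logged")


asyncio.run(main())
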