Skip to content

Commit

Permalink
🐛♻️ Better context cleanup (#2586)
Browse files Browse the repository at this point in the history
* refactor & add cleanup

* attached events to stop

* docs and pylint

* adding new test and refactored exiting

* remove unused

* renamed function to avoid confusion

Co-authored-by: Andrei Neagu <[email protected]>
  • Loading branch information
GitHK and Andrei Neagu authored Oct 20, 2021
1 parent aff0b8e commit c215bf9
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 158 deletions.
55 changes: 45 additions & 10 deletions packages/service-library/src/servicelib/async_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
import asyncio
import logging
from collections import deque
from functools import wraps
from typing import Dict, List
from typing import Dict, List, Optional

import attr

logger = logging.getLogger(__name__)


@attr.s(auto_attribs=True)
class Context:
    """Per-key state for a run_sequentially_in_context sequential worker."""

    in_queue: asyncio.Queue  # pending awaitables; a None item is the shutdown sentinel
    out_queue: asyncio.Queue  # result (or captured exception) for each processed awaitable
    initialized: bool  # whether the worker has been started — set outside this view, TODO confirm
    task: Optional[asyncio.Task] = None  # the worker task; None until it is created


_sequential_jobs_contexts: Dict[str, Context] = {}

sequential_jobs_contexts = {}

async def stop_sequential_workers() -> None:
    """Signals all workers to close thus avoiding errors on shutdown"""
    # A None item on in_queue is the agreed shutdown sentinel for each worker.
    for context in _sequential_jobs_contexts.values():
        await context.in_queue.put(None)
        if context.task is not None:
            # wait for the worker task to consume the sentinel and exit cleanly
            await context.task
    # drop all contexts so a later call starts from a clean slate
    _sequential_jobs_contexts.clear()
    logger.info("All run_sequentially_in_context pending workers stopped")


def run_sequentially_in_context(target_args: List[str] = None):
Expand All @@ -35,11 +49,13 @@ async def func(param1, param2, param3):
functions = [
func(1, "something", 3),
func(1, "else", 3),
func(1, "argument.attribute", 3),
func(1, "here", 3),
]
await asyncio.gather(*functions)
note the special "argument.attribute", which will use the attribute of argument to create the context.
The following calls will run in parallel, because they have different contexts:
functions = [
Expand All @@ -62,24 +78,34 @@ def get_context(args, kwargs: Dict) -> Context:

key_parts = deque()
for arg in target_args:
if arg not in search_args:
sub_args = arg.split(".")
main_arg = sub_args[0]
if main_arg not in search_args:
message = (
f"Expected '{arg}' in '{decorated_function.__name__}'"
f"Expected '{main_arg}' in '{decorated_function.__name__}'"
f" arguments. Got '{search_args}'"
)
raise ValueError(message)
key_parts.append(search_args[arg])
context_key = search_args[main_arg]
for attribute in sub_args[1:]:
potential_key = getattr(context_key, attribute)
if not potential_key:
message = f"Expected '{attribute}' attribute in '{context_key.__name__}' arguments."
raise ValueError(message)
context_key = potential_key

key_parts.append(f"{decorated_function.__name__}_{context_key}")

key = ":".join(map(str, key_parts))

if key not in sequential_jobs_contexts:
sequential_jobs_contexts[key] = Context(
if key not in _sequential_jobs_contexts:
_sequential_jobs_contexts[key] = Context(
in_queue=asyncio.Queue(),
out_queue=asyncio.Queue(),
initialized=False,
)

return sequential_jobs_contexts[key]
return _sequential_jobs_contexts[key]

@wraps(decorated_function)
async def wrapper(*args, **kwargs):
Expand All @@ -92,13 +118,22 @@ async def worker(in_q: asyncio.Queue, out_q: asyncio.Queue):
while True:
awaitable = await in_q.get()
in_q.task_done()
# check if requested to shutdown
if awaitable is None:
break
try:
result = await awaitable
except Exception as e: # pylint: disable=broad-except
result = e
await out_q.put(result)

asyncio.get_event_loop().create_task(
logging.info(
"Closed worker for @run_sequentially_in_context applied to '%s' with target_args=%s",
decorated_function.__name__,
target_args,
)

context.task = asyncio.create_task(
worker(context.in_queue, context.out_queue)
)

Expand Down
109 changes: 89 additions & 20 deletions packages/service-library/tests/test_async_utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,47 @@
# pylint: disable=redefined-outer-name
# pylint: disable=unused-argument

import asyncio
import copy
import random
from collections import deque
from dataclasses import dataclass
from time import time
from typing import Any, Dict, List
from typing import Any, AsyncIterable, Dict, List, Optional

import pytest
from servicelib.async_utils import run_sequentially_in_context, sequential_jobs_contexts
from servicelib.async_utils import (
_sequential_jobs_contexts,
run_sequentially_in_context,
stop_sequential_workers,
)

RETRIES = 10
DIFFERENT_CONTEXTS_COUNT = 10

@pytest.fixture(autouse=True)
def ensure_run_in_sequence_context_is_empty():
# NOTE: since the contexts variable is initialized at import time, when several test run
# the import happens only once and is rendered invalid, therefore explicit clearance is necessary
sequential_jobs_contexts.clear()

@pytest.fixture
async def ensure_run_in_sequence_context_is_empty(loop) -> AsyncIterable[None]:
    """Teardown-only fixture: stops all sequential workers after the test."""
    yield
    # NOTE
    # required when shutting down the application or ending tests
    # otherwise errors will occur when closing the loop
    await stop_sequential_workers()


@pytest.fixture
def payload() -> str:
    """Fixed string used as the context-key payload in the tests below."""
    fixed_payload = "some string payload"
    return fixed_payload


@pytest.fixture
def expected_param_name() -> str:
    """Parameter name used to exercise the wrong-target-args error path."""
    param_name = "expected_param_name"
    return param_name


@pytest.fixture
def sleep_duration() -> float:
    """Short pause simulating work inside the test coroutines."""
    duration = 0.01
    return duration


class LockedStore:
Expand All @@ -34,12 +60,14 @@ async def get_all(self) -> List[Any]:
return list(self._queue)


async def test_context_aware_dispatch() -> None:
async def test_context_aware_dispatch(
sleep_duration: float,
ensure_run_in_sequence_context_is_empty: None,
) -> None:
@run_sequentially_in_context(target_args=["c1", "c2", "c3"])
async def orderly(c1: Any, c2: Any, c3: Any, control: Any) -> None:
_ = (c1, c2, c3)
sleep_interval = random.uniform(0, 0.01)
await asyncio.sleep(sleep_interval)
await asyncio.sleep(sleep_duration)

context = dict(c1=c1, c2=c2, c3=c3)
await locked_stores[make_key_from_context(context)].push(control)
Expand Down Expand Up @@ -81,12 +109,14 @@ def make_context():
assert list(expected_outcomes[key]) == await locked_stores[key].get_all()


async def test_context_aware_function_sometimes_fails() -> None:
async def test_context_aware_function_sometimes_fails(
ensure_run_in_sequence_context_is_empty: None,
) -> None:
class DidFailException(Exception):
pass

@run_sequentially_in_context(target_args=["will_fail"])
async def sometimes_failing(will_fail: bool) -> None:
async def sometimes_failing(will_fail: bool) -> bool:
if will_fail:
raise DidFailException("I was instructed to fail")
return True
Expand All @@ -101,8 +131,10 @@ async def sometimes_failing(will_fail: bool) -> None:
assert await sometimes_failing(raise_error) is True


async def test_context_aware_wrong_target_args_name() -> None:
expected_param_name = "wrong_parameter"
async def test_context_aware_wrong_target_args_name(
expected_param_name: str,
ensure_run_in_sequence_context_is_empty: None, # pylint: disable=unused-argument
) -> None:

# pylint: disable=unused-argument
@run_sequentially_in_context(target_args=[expected_param_name])
Expand All @@ -119,15 +151,17 @@ async def target_function(the_param: Any) -> None:
assert str(excinfo.value).startswith(message) is True


async def test_context_aware_measure_parallelism() -> None:
async def test_context_aware_measure_parallelism(
sleep_duration: float,
ensure_run_in_sequence_context_is_empty: None,
) -> None:
# expected duration 1 second
@run_sequentially_in_context(target_args=["control"])
async def sleep_for(sleep_interval: float, control: Any) -> Any:
await asyncio.sleep(sleep_interval)
return control

control_sequence = list(range(1000))
sleep_duration = 0.5
control_sequence = list(range(RETRIES))
functions = [sleep_for(sleep_duration, x) for x in control_sequence]

start = time()
Expand All @@ -138,15 +172,17 @@ async def sleep_for(sleep_interval: float, control: Any) -> Any:
assert control_sequence == result


async def test_context_aware_measure_serialization() -> None:
async def test_context_aware_measure_serialization(
sleep_duration: float,
ensure_run_in_sequence_context_is_empty: None,
) -> None:
# expected duration 1 second
@run_sequentially_in_context(target_args=["control"])
async def sleep_for(sleep_interval: float, control: Any) -> Any:
await asyncio.sleep(sleep_interval)
return control

control_sequence = [1 for _ in range(10)]
sleep_duration = 0.1
control_sequence = [1 for _ in range(RETRIES)]
functions = [sleep_for(sleep_duration, x) for x in control_sequence]

start = time()
Expand All @@ -156,3 +192,36 @@ async def sleep_for(sleep_interval: float, control: Any) -> Any:
minimum_timelapse = (sleep_duration) * len(control_sequence)
assert elapsed > minimum_timelapse
assert control_sequence == result


async def test_nested_object_attribute(
    payload: str,
    ensure_run_in_sequence_context_is_empty: None,
) -> None:
    """Checks that target_args supports dotted attribute access on arguments."""

    @dataclass
    class ObjectWithPropos:
        # default value comes from the `payload` fixture
        attr1: str = payload

    # "object_with_props.attr1" keys the context on the argument's attribute
    @run_sequentially_in_context(target_args=["object_with_props.attr1"])
    async def test_attribute(
        object_with_props: ObjectWithPropos, other_attr: Optional[int] = None
    ) -> str:
        return object_with_props.attr1

    # repeated calls reuse the same context keyed on the attribute's value
    for _ in range(RETRIES):
        assert payload == await test_attribute(ObjectWithPropos())


async def test_different_contexts(
    payload: str,
    ensure_run_in_sequence_context_is_empty: None,
) -> None:
    """Each distinct argument value must get its own sequential context."""

    @run_sequentially_in_context(target_args=["context_param"])
    async def test_multiple_context_calls(context_param: int) -> int:
        return context_param

    for _ in range(RETRIES):
        for i in range(DIFFERENT_CONTEXTS_COUNT):
            assert i == await test_multiple_context_calls(i)

    # One context is created per distinct context_param value, independent of
    # how many times each value is called. The previous assertion compared
    # against RETRIES, which only passed because RETRIES and
    # DIFFERENT_CONTEXTS_COUNT both happen to equal 10.
    assert len(_sequential_jobs_contexts) == DIFFERENT_CONTEXTS_COUNT
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi import APIRouter, Depends, HTTPException
from models_library.projects import ProjectAtDB, ProjectID
from models_library.projects_state import RunningState
from servicelib.async_utils import run_sequentially_in_context
from starlette import status
from starlette.requests import Request
from tenacity import (
Expand All @@ -32,7 +33,6 @@
from ...modules.db.repositories.comp_tasks import CompTasksRepository
from ...modules.db.repositories.projects import ProjectsRepository
from ...modules.director_v0 import DirectorV0Client
from ...utils.async_utils import run_sequentially_in_context
from ...utils.computations import (
get_pipeline_state_from_task_states,
is_pipeline_running,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from servicelib.async_utils import stop_sequential_workers

from ..meta import PROJECT_NAME, __version__

#
Expand All @@ -16,10 +18,11 @@
)


def on_startup() -> None:
async def on_startup() -> None:
    """Prints the welcome banner when the application starts."""
    # async signature so it can be registered as a startup event handler
    # alongside the async on_shutdown — presumably required by the framework
    print(WELCOME_MSG, flush=True)


def on_shutdown() -> None:
async def on_shutdown() -> None:
    """Stops pending sequential workers, then prints the shutdown banner."""
    # workers must be stopped before the loop closes to avoid shutdown errors
    await stop_sequential_workers()
    msg = PROJECT_NAME + f" v{__version__} SHUT DOWN"
    print(f"{msg:=^100}", flush=True)
Loading

0 comments on commit c215bf9

Please sign in to comment.