diff --git a/iso15118/evcc/comm_session_handler.py b/iso15118/evcc/comm_session_handler.py index 079c1696..b442234e 100644 --- a/iso15118/evcc/comm_session_handler.py +++ b/iso15118/evcc/comm_session_handler.py @@ -57,7 +57,7 @@ StopNotification, UDPPacketNotification, ) -from iso15118.shared.utils import cancel_task, wait_till_finished +from iso15118.shared.utils import cancel_task, wait_for_tasks logger = logging.getLogger(__name__) @@ -290,7 +290,7 @@ async def __init__. Therefore, we need to create a separate async logger.info("Communication session handler started") - await wait_till_finished(self.list_of_tasks) + await wait_for_tasks(self.list_of_tasks) async def send_sdp(self): """ diff --git a/iso15118/secc/comm_session_handler.py b/iso15118/secc/comm_session_handler.py index 089a30de..c1782e28 100644 --- a/iso15118/secc/comm_session_handler.py +++ b/iso15118/secc/comm_session_handler.py @@ -52,7 +52,7 @@ TCPClientNotification, UDPPacketNotification, ) -from iso15118.shared.utils import cancel_task, wait_till_finished +from iso15118.shared.utils import cancel_task, wait_for_tasks logger = logging.getLogger(__name__) @@ -195,7 +195,7 @@ async def __init__. logger.info("Communication session handler started") - await wait_till_finished(self.list_of_tasks) + await wait_for_tasks(self.list_of_tasks) async def get_from_rcv_queue(self, queue: asyncio.Queue): """ diff --git a/iso15118/secc/transport/udp_server.py b/iso15118/secc/transport/udp_server.py index 6ca7d678..c7b002db 100644 --- a/iso15118/secc/transport/udp_server.py +++ b/iso15118/secc/transport/udp_server.py @@ -11,7 +11,7 @@ ReceiveTimeoutNotification, UDPPacketNotification, ) -from iso15118.shared.utils import wait_till_finished +from iso15118.shared.utils import wait_for_tasks logger = logging.getLogger(__name__) @@ -100,7 +100,7 @@ async def start(self): f"and port {SDP_SERVER_PORT}" ) tasks = [self.rcv_task()] - await wait_till_finished(tasks) + await wait_for_tasks(tasks) def connection_made(self, transport): """ diff --git a/iso15118/shared/comm_session.py b/iso15118/shared/comm_session.py index 1e03dad7..8b25beb1 100644 --- a/iso15118/shared/comm_session.py +++ b/iso15118/shared/comm_session.py @@ -50,7 +50,7 @@ from iso15118.shared.messages.v2gtp import V2GTPMessage from iso15118.shared.notifications import StopNotification from iso15118.shared.states import Pause, State, Terminate -from iso15118.shared.utils import wait_till_finished +from iso15118.shared.utils import wait_for_tasks logger = logging.getLogger(__name__) @@ -338,7 +338,7 @@ async def start(self, timeout: float): try: self._started = True - await wait_till_finished(tasks) + await wait_for_tasks(tasks) finally: self._started = False diff --git a/iso15118/shared/utils.py b/iso15118/shared/utils.py index 78d4c2b1..3737a99c 100644 --- a/iso15118/shared/utils.py +++ b/iso15118/shared/utils.py @@ -1,27 +1,10 @@ -""" -This module contains methods for managing multiple asnycio tasks that are -supposed to run concurrently. -""" - import asyncio -import json import logging -import os -from contextlib import suppress -from typing import Any, Awaitable, List +from typing import Coroutine, List logger = logging.getLogger(__name__) -def load_from_env(variable, default=None): - """Read values from the environment and try to convert values from json""" - value = os.environ.get(variable, default) - if value is not None: - with suppress(json.decoder.JSONDecodeError, TypeError): - value = json.loads(value) - return value - - async def cancel_task(task): """Cancel the task safely""" task.cancel() @@ -31,76 +14,38 @@ async def cancel_task(task): pass -async def wait_till_finished( - awaitables: List[Awaitable[Any]], finished_when=asyncio.FIRST_EXCEPTION +async def wait_for_tasks( + await_tasks: List[Coroutine], return_when=asyncio.FIRST_EXCEPTION ): - """Run the tasks until one task is finished. The condition to finish - depends on the argument 'finished_when', which directly translates - to the asyncio.wait argument 'return_when' that can assume the following - values: FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED - (For more information regarding this, please check: + """ + Method to run multiple tasks concurrently. + return_when is used directly in the asyncio.wait call and sets the + condition to cancel all running tasks and return. + The arguments for it can be: + asyncio.FIRST_COMPLETED, asyncio.FIRST_EXCEPTION or + asyncio.ALL_COMPLETED + check: https://docs.python.org/3/library/asyncio-task.html#waiting-primitives) - All unfinished tasks will be cancelled. - - It can happen that multiple tasks finished at the same time. - A MultiError is raised if at least one task finished with an exception. - This exception wraps the exception of all tasks that finished with an - exception. - - Return values of finished tasks are ignored. Use `asyncio.wait()` directly - if you need access to the return values of tasks. - - If this function turns out to be useful it might be a good fit for - `common/util` or `cc_utils`. + Similar solutions for awaiting for several tasks can be found in: + * https://python.plainenglish.io/how-to-manage-exceptions-when-waiting-on-multiple-asyncio-tasks-a5530ac10f02 # noqa: E501 + * https://stackoverflow.com/questions/63583822/asyncio-wait-on-multiple-tasks-with-timeout-and-cancellation # noqa: E501 """ tasks = [] - # As of Python 3.8 `asyncio.wait()` should be called only with - # `asyncio.Task`s. - # See: https://docs.python.org/3/library/asyncio-task.html#asyncio-example-wait-coroutine # noqa: E501 - for awaitable in awaitables: - if not isinstance(awaitable, asyncio.Task): - awaitable = asyncio.create_task(awaitable) - tasks.append(awaitable) + for task in await_tasks: + if not isinstance(task, asyncio.Task): + task = asyncio.create_task(task) + tasks.append(task) - done, pending = await asyncio.wait(tasks, return_when=finished_when) + done, pending = await asyncio.wait(tasks, return_when=return_when) for task in pending: await cancel_task(task) - errors = [] for task in done: try: task.result() - except Exception as ex: - logger.exception(ex) - errors.append(ex) - - if len(errors) == 1: - raise errors[0] - - if errors: - raise MultiError(errors) - - -class MultiError(Exception): - """Exception used to raise multiple exceptions. - - The attribute `errors` gives access to the wrapper errors. - - try: - something() - except MultiError as e: - for error in e.errors: - if isinstance(e, ZeroDivisionError): - ... - elif isinstance(e, AttributeError): - ... - - """ - - def __init__(self, errors: List[Exception]): - Exception.__init__(self) - self.errors = errors + except Exception as e: + logger.exception(e)