From 7e5dcc19953ca2f3009b1c10c55104df530ca840 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Tue, 1 Aug 2023 17:08:42 +1200 Subject: [PATCH 1/4] :sparkles: MapperAsyncIterDataPipe for applying custom async functions An asynchronous iterable-style DataPipe for applying a custom asynchronous function over an asynchronous iterable! Uses asyncio.TaskGroup from Python 3.11+ to run several tasks concurrently. Included a doctest, added a new section in the API docs under 'Mapping DataPipes', and set show_toc_level to 3 to make it show in the right sidebar. --- bambooflow/datapipes/__init__.py | 1 + bambooflow/datapipes/callable.py | 69 ++++++++++++++++++++++++++++++++ docs/_config.yml | 3 ++ docs/api.md | 10 +++++ 4 files changed, 83 insertions(+) create mode 100644 bambooflow/datapipes/callable.py diff --git a/bambooflow/datapipes/__init__.py b/bambooflow/datapipes/__init__.py index 0d66f1e..fe62593 100644 --- a/bambooflow/datapipes/__init__.py +++ b/bambooflow/datapipes/__init__.py @@ -10,3 +10,4 @@ AsyncIterDataPipe, AsyncIterableWrapperAsyncIterDataPipe as AsyncIterableWrapper, ) +from bambooflow.datapipes.callable import MapperAsyncIterDataPipe as Mapper diff --git a/bambooflow/datapipes/callable.py b/bambooflow/datapipes/callable.py new file mode 100644 index 0000000..ac20e44 --- /dev/null +++ b/bambooflow/datapipes/callable.py @@ -0,0 +1,69 @@ +""" +Asynchronous Iterable DataPipes for asynchronous functions. +""" +import asyncio +from collections.abc import AsyncIterator, Callable, Coroutine +from typing import Any + +from bambooflow.datapipes.aiter import AsyncIterDataPipe + + +class MapperAsyncIterDataPipe(AsyncIterDataPipe): + """ + Applies an asynchronous function over each item from the source DataPipe. + + Parameters + ---------- + datapipe : AsyncIterDataPipe + The source asynchronous iterable-style DataPipe. + fn : Callable + Asynchronous function to be applied over each item. 
+ + Yields + ------ + awaitable : collections.abc.Awaitable + An :py-term:`awaitable` object from the + :py-term:`asynchronous iterator `. + + Example + ------- + >>> import asyncio + >>> from bambooflow.datapipes import AsyncIterableWrapper, Mapper + ... + >>> # Apply an asynchronous multiply by two function + >>> async def times_two(x) -> float: + ... await asyncio.sleep(delay=x) + ... return x * 2 + >>> dp = AsyncIterableWrapper(iterable=[0.1, 0.2, 0.3]) + >>> dp_map = Mapper(datapipe=dp, fn=times_two) + ... + >>> # Loop or iterate over the DataPipe stream + >>> it = aiter(dp_map) + >>> number = anext(it) + >>> asyncio.run(number) + 0.2 + >>> number = anext(it) + >>> asyncio.run(number) + 0.4 + >>> # Or if running in an interactive REPL with top-level `await` support + >>> number = anext(it) + >>> await number # doctest: +SKIP + 0.6 + """ + + def __init__( + self, datapipe: AsyncIterDataPipe, fn: Callable[..., Coroutine[Any, Any, Any]] + ): + super().__init__() + self._datapipe = datapipe + self._fn = fn + + async def __aiter__(self) -> AsyncIterator: + async with asyncio.TaskGroup() as task_group: + tasks: list[asyncio.Task] = [ + task_group.create_task(coro=self._fn(data)) + async for data in self._datapipe + ] + for task in tasks: + result = await task + yield result diff --git a/docs/_config.yml b/docs/_config.yml index cd52d4c..8ef90bd 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -31,6 +31,9 @@ sphinx: config: myst_all_links_external: true html_show_copyright: false + html_theme_options: + # https://sphinx-book-theme.readthedocs.io/en/stable/customize/sidebar-secondary.html + show_toc_level: 3 extlinks: py-term: - 'https://docs.python.org/3/glossary.html#term-%s' diff --git a/docs/api.md b/docs/api.md index f570802..3b2551b 100644 --- a/docs/api.md +++ b/docs/api.md @@ -11,3 +11,13 @@ .. 
autoclass:: bambooflow.datapipes.aiter.AsyncIterableWrapperAsyncIterDataPipe :show-inheritance: ``` + +### Mapping DataPipes + +Datapipes which apply a custom asynchronous function to elements in a DataPipe. + +```{eval-rst} +.. autoclass:: bambooflow.datapipes.Mapper +.. autoclass:: bambooflow.datapipes.callable.MapperAsyncIterDataPipe + :show-inheritance: +``` From 6901b86cc968e033c93ec4aea064ddfa05e55779 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Thu, 3 Aug 2023 15:29:18 +1200 Subject: [PATCH 2/4] :white_check_mark: Add unit test for MapperAsyncIterDataPipe Ensure that tasks are processed concurrently, included a timer to double check that all 3 tasks (made up of 2 sub-tasks) complete in 0.5 seconds instead of 1.5 seconds! --- bambooflow/tests/test_datapipes_callable.py | 57 +++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 bambooflow/tests/test_datapipes_callable.py diff --git a/bambooflow/tests/test_datapipes_callable.py b/bambooflow/tests/test_datapipes_callable.py new file mode 100644 index 0000000..f2c489f --- /dev/null +++ b/bambooflow/tests/test_datapipes_callable.py @@ -0,0 +1,57 @@ +""" +Tests for callable datapipes. 
+""" +import asyncio +import time +from collections.abc import Awaitable + +import pytest + +from bambooflow.datapipes import AsyncIterableWrapper, Mapper + + +# %% +@pytest.fixture(scope="function", name="times_two") +def fixture_times_two(): + async def times_two(x) -> int: + await asyncio.sleep(0.2) + print(f"Multiplying {x} by 2") + result = x * 2 + return result + + return times_two + + +@pytest.fixture(scope="function", name="times_three") +def fixture_times_three(): + async def times_three(x) -> int: + await asyncio.sleep(0.3) + print(f"Multiplying {x} by 3") + result = x * 3 + return result + + return times_three + + +async def test_mapper(times_two, times_three): + """ + Ensure that MapperAsyncIterDataPipe works to process tasks concurrently, + such that three tasks taking 3*(0.2+0.3)=1.5 seconds in serial can be + completed in just (0.2+0.3)=0.5 seconds instead. + """ + dp = AsyncIterableWrapper(iterable=[0, 1, 2]) + dp_map2 = Mapper(datapipe=dp, fn=times_two) + dp_map3 = Mapper(datapipe=dp_map2, fn=times_three) + + i = 0 + tic = time.perf_counter() + async for num in dp_map3: + # print("Number:", num) + assert num == i * 2 * 3 + toc = time.perf_counter() + i += 1 + # print(f"Ran in {toc - tic:0.4f} seconds") + print(f"Total: {toc - tic:0.4f} seconds") + + assert toc - tic < 0.55 # Total time should be about 0.5 seconds + assert num == 12 # 2*2*3=12 From 282bb26c8e07ff6924d16582aa91e82e9cab29d2 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Thu, 3 Aug 2023 18:06:35 +1200 Subject: [PATCH 3/4] :goal_net: Use try-except* to catch task errors in ExceptionGroup To better handle errors from tasks in a TaskGroup, wrap the TaskGroup context manager in a try-except* clause following PEP 654 Exception Groups. Based on the nice examples from https://github.com/jrfk/talk/tree/main/EuroPython2023. Added a unit test to ensure that a ValueError raised in 1 out of 3 tasks can be nicely captured and raised to attention. 
--- bambooflow/datapipes/callable.py | 14 +++++--- bambooflow/tests/test_datapipes_callable.py | 39 ++++++++++++++++++++- 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/bambooflow/datapipes/callable.py b/bambooflow/datapipes/callable.py index ac20e44..aa123d2 100644 --- a/bambooflow/datapipes/callable.py +++ b/bambooflow/datapipes/callable.py @@ -59,11 +59,15 @@ def __init__( self._fn = fn async def __aiter__(self) -> AsyncIterator: - async with asyncio.TaskGroup() as task_group: - tasks: list[asyncio.Task] = [ - task_group.create_task(coro=self._fn(data)) - async for data in self._datapipe - ] + try: + async with asyncio.TaskGroup() as task_group: + tasks: list[asyncio.Task] = [ + task_group.create_task(coro=self._fn(data)) + async for data in self._datapipe + ] + except* BaseException as err: + raise ValueError(f"{err=}") from err + for task in tasks: result = await task yield result diff --git a/bambooflow/tests/test_datapipes_callable.py b/bambooflow/tests/test_datapipes_callable.py index f2c489f..79b9cbd 100644 --- a/bambooflow/tests/test_datapipes_callable.py +++ b/bambooflow/tests/test_datapipes_callable.py @@ -2,6 +2,7 @@ Tests for callable datapipes. 
""" import asyncio +import re import time from collections.abc import Awaitable @@ -33,7 +34,18 @@ async def times_three(x) -> int: return times_three -async def test_mapper(times_two, times_three): +@pytest.fixture(scope="function", name="error_four") +def fixture_error_four(): + async def error_four(x): + await asyncio.sleep(0.1) + if x == 4: + raise ValueError(f"Some problem with {x}") + + return error_four + + +# %% +async def test_mapper_concurrency(times_two, times_three): """ Ensure that MapperAsyncIterDataPipe works to process tasks concurrently, such that three tasks taking 3*(0.2+0.3)=1.5 seconds in serial can be @@ -55,3 +67,28 @@ async def test_mapper(times_two, times_three): assert toc - tic < 0.55 # Total time should be about 0.5 seconds assert num == 12 # 2*2*3=12 + + +async def test_mapper_exception_handling(error_four): + """ + Ensure that MapperAsyncIterDataPipe can capture exceptions when one of the + tasks raises an error. + """ + dp = AsyncIterableWrapper(iterable=[3, 4, 5]) + dp_map = Mapper(datapipe=dp, fn=error_four) + + it = aiter(dp_map) + number = anext(it) + # Check that an ExceptionGroup is already raised on first access + with pytest.raises( + ValueError, + match=re.escape( + "err=ExceptionGroup('unhandled errors in a TaskGroup', [ValueError('Some problem with 4')])" + ), + ): + await number + + # Subsequent access to iterator should raise StopAsyncIteration + number = anext(it) + with pytest.raises(StopAsyncIteration): + await number From d2482e96ef6cabeef37233890cdbde1f92768526 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Thu, 3 Aug 2023 18:22:45 +1200 Subject: [PATCH 4/4] :memo: Document the raising of ExceptionGroup when a task errors out Mention PEP0654 so that people know what an ExceptionGroup is, and how it could be handled. 
--- bambooflow/datapipes/callable.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bambooflow/datapipes/callable.py b/bambooflow/datapipes/callable.py index aa123d2..085c3d0 100644 --- a/bambooflow/datapipes/callable.py +++ b/bambooflow/datapipes/callable.py @@ -25,6 +25,13 @@ class MapperAsyncIterDataPipe(AsyncIterDataPipe): An :py-term:`awaitable` object from the :py-term:`asynchronous iterator `. + Raises + ------ + ExceptionGroup + If any one of the concurrent tasks raises an :py:class:`Exception`. See + `PEP654 <https://peps.python.org/pep-0654/>`_ + for general advice on how to handle exception groups. + Example ------- >>> import asyncio