diff --git a/bambooflow/datapipes/__init__.py b/bambooflow/datapipes/__init__.py index 0d66f1e..fe62593 100644 --- a/bambooflow/datapipes/__init__.py +++ b/bambooflow/datapipes/__init__.py @@ -10,3 +10,4 @@ AsyncIterDataPipe, AsyncIterableWrapperAsyncIterDataPipe as AsyncIterableWrapper, ) +from bambooflow.datapipes.callable import MapperAsyncIterDataPipe as Mapper diff --git a/bambooflow/datapipes/callable.py b/bambooflow/datapipes/callable.py new file mode 100644 index 0000000..085c3d0 --- /dev/null +++ b/bambooflow/datapipes/callable.py @@ -0,0 +1,80 @@ +""" +Asynchronous Iterable DataPipes for asynchronous functions. +""" +import asyncio +from collections.abc import AsyncIterator, Callable, Coroutine +from typing import Any + +from bambooflow.datapipes.aiter import AsyncIterDataPipe + + +class MapperAsyncIterDataPipe(AsyncIterDataPipe): + """ + Applies an asynchronous function over each item from the source DataPipe. + + Parameters + ---------- + datapipe : AsyncIterDataPipe + The source asynchronous iterable-style DataPipe. + fn : Callable + Asynchronous function to be applied over each item. + + Yields + ------ + awaitable : collections.abc.Awaitable + An :py-term:`awaitable` object from the + :py-term:`asynchronous iterator `. + + Raises + ------ + ExceptionGroup + If any one of the concurrent tasks raises an :py:class:`Exception`. See + `PEP654 `_ + for general advice on how to handle exception groups. + + Example + ------- + >>> import asyncio + >>> from bambooflow.datapipes import AsyncIterableWrapper, Mapper + ... + >>> # Apply an asynchronous multiply by two function + >>> async def times_two(x) -> float: + ... await asyncio.sleep(delay=x) + ... return x * 2 + >>> dp = AsyncIterableWrapper(iterable=[0.1, 0.2, 0.3]) + >>> dp_map = Mapper(datapipe=dp, fn=times_two) + ... + >>> # Loop or iterate over the DataPipe stream + >>> it = aiter(dp_map) + >>> number = anext(it) + >>> asyncio.run(number) + 0.2 + >>> number = anext(it) + >>> asyncio.run(number) + 0.4 + >>> # Or if running in an interactive REPL with top-level `await` support + >>> number = anext(it) + >>> await number # doctest: +SKIP + 0.6 + """ + + def __init__( + self, datapipe: AsyncIterDataPipe, fn: Callable[..., Coroutine[Any, Any, Any]] + ): + super().__init__() + self._datapipe = datapipe + self._fn = fn + + async def __aiter__(self) -> AsyncIterator: + try: + async with asyncio.TaskGroup() as task_group: + tasks: list[asyncio.Task] = [ + task_group.create_task(coro=self._fn(data)) + async for data in self._datapipe + ] + except* BaseException as err: + raise ValueError(f"{err=}") from err + + for task in tasks: + result = await task + yield result diff --git a/bambooflow/tests/test_datapipes_callable.py b/bambooflow/tests/test_datapipes_callable.py new file mode 100644 index 0000000..79b9cbd --- /dev/null +++ b/bambooflow/tests/test_datapipes_callable.py @@ -0,0 +1,94 @@ +""" +Tests for callable datapipes. +""" +import asyncio +import re +import time +from collections.abc import Awaitable + +import pytest + +from bambooflow.datapipes import AsyncIterableWrapper, Mapper + + +# %% +@pytest.fixture(scope="function", name="times_two") +def fixture_times_two(): + async def times_two(x) -> int: + await asyncio.sleep(0.2) + print(f"Multiplying {x} by 2") + result = x * 2 + return result + + return times_two + + +@pytest.fixture(scope="function", name="times_three") +def fixture_times_three(): + async def times_three(x) -> int: + await asyncio.sleep(0.3) + print(f"Multiplying {x} by 3") + result = x * 3 + return result + + return times_three + + +@pytest.fixture(scope="function", name="error_four") +def fixture_error_four(): + async def error_four(x): + await asyncio.sleep(0.1) + if x == 4: + raise ValueError(f"Some problem with {x}") + + return error_four + + +# %% +async def test_mapper_concurrency(times_two, times_three): + """ + Ensure that MapperAsyncIterDataPipe works to process tasks concurrently, + such that three tasks taking 3*(0.2+0.3)=1.5 seconds in serial can be + completed in just (0.2+0.3)=0.5 seconds instead. + """ + dp = AsyncIterableWrapper(iterable=[0, 1, 2]) + dp_map2 = Mapper(datapipe=dp, fn=times_two) + dp_map3 = Mapper(datapipe=dp_map2, fn=times_three) + + i = 0 + tic = time.perf_counter() + async for num in dp_map3: + # print("Number:", num) + assert num == i * 2 * 3 + toc = time.perf_counter() + i += 1 + # print(f"Ran in {toc - tic:0.4f} seconds") + print(f"Total: {toc - tic:0.4f} seconds") + + assert toc - tic < 0.55 # Total time should be about 0.5 seconds + assert num == 12 # 2*2*3=12 + + +async def test_mapper_exception_handling(error_four): + """ + Ensure that MapperAsyncIterDataPipe can capture exceptions when one of the + tasks raises an error. + """ + dp = AsyncIterableWrapper(iterable=[3, 4, 5]) + dp_map = Mapper(datapipe=dp, fn=error_four) + + it = aiter(dp_map) + number = anext(it) + # Checek that an ExceptionGroup is already raised on first access + with pytest.raises( + ValueError, + match=re.escape( + "err=ExceptionGroup('unhandled errors in a TaskGroup', [ValueError('Some problem with 4')])" + ), + ): + await number + + # Subsequent access to iterator should raise StopAsyncIteration + number = anext(it) + with pytest.raises(StopAsyncIteration): + await number diff --git a/docs/_config.yml b/docs/_config.yml index cd52d4c..8ef90bd 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -31,6 +31,9 @@ sphinx: config: myst_all_links_external: true html_show_copyright: false + html_theme_options: + # https://sphinx-book-theme.readthedocs.io/en/stable/customize/sidebar-secondary.html + show_toc_level: 3 extlinks: py-term: - 'https://docs.python.org/3/glossary.html#term-%s' diff --git a/docs/api.md b/docs/api.md index f570802..3b2551b 100644 --- a/docs/api.md +++ b/docs/api.md @@ -11,3 +11,13 @@ .. autoclass:: bambooflow.datapipes.aiter.AsyncIterableWrapperAsyncIterDataPipe :show-inheritance: ``` + +### Mapping DataPipes + +Datapipes which apply a custom asynchronous function to elements in a DataPipe. + +```{eval-rst} +.. autoclass:: bambooflow.datapipes.Mapper +.. autoclass:: bambooflow.datapipes.callable.MapperAsyncIterDataPipe + :show-inheritance: +```