Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ MapperAsyncIterDataPipe for applying custom async functions #9

Merged
merged 4 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bambooflow/datapipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
AsyncIterDataPipe,
AsyncIterableWrapperAsyncIterDataPipe as AsyncIterableWrapper,
)
from bambooflow.datapipes.callable import MapperAsyncIterDataPipe as Mapper
80 changes: 80 additions & 0 deletions bambooflow/datapipes/callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
Asynchronous Iterable DataPipes for asynchronous functions.
"""
import asyncio
from collections.abc import AsyncIterator, Callable, Coroutine
from typing import Any

from bambooflow.datapipes.aiter import AsyncIterDataPipe


class MapperAsyncIterDataPipe(AsyncIterDataPipe):
"""
Applies an asynchronous function over each item from the source DataPipe.

Parameters
----------
datapipe : AsyncIterDataPipe
The source asynchronous iterable-style DataPipe.
fn : Callable
Asynchronous function to be applied over each item.

Yields
------
awaitable : collections.abc.Awaitable
An :py-term:`awaitable` object from the
:py-term:`asynchronous iterator <asynchronous-iterator>`.

Raises
------
ExceptionGroup
If any one of the concurrent tasks raises an :py:class:`Exception`. See
`PEP654 <https://peps.python.org/pep-0654/#handling-exception-groups>`_
for general advice on how to handle exception groups.

Example
-------
>>> import asyncio
>>> from bambooflow.datapipes import AsyncIterableWrapper, Mapper
...
>>> # Apply an asynchronous multiply by two function
>>> async def times_two(x) -> float:
... await asyncio.sleep(delay=x)
... return x * 2
>>> dp = AsyncIterableWrapper(iterable=[0.1, 0.2, 0.3])
>>> dp_map = Mapper(datapipe=dp, fn=times_two)
...
>>> # Loop or iterate over the DataPipe stream
>>> it = aiter(dp_map)
>>> number = anext(it)
>>> asyncio.run(number)
0.2
>>> number = anext(it)
>>> asyncio.run(number)
0.4
>>> # Or if running in an interactive REPL with top-level `await` support
>>> number = anext(it)
>>> await number # doctest: +SKIP
0.6
"""

def __init__(
self, datapipe: AsyncIterDataPipe, fn: Callable[..., Coroutine[Any, Any, Any]]
):
super().__init__()
self._datapipe = datapipe
self._fn = fn

async def __aiter__(self) -> AsyncIterator:
try:
async with asyncio.TaskGroup() as task_group:
tasks: list[asyncio.Task] = [
task_group.create_task(coro=self._fn(data))
async for data in self._datapipe
]
except* BaseException as err:
raise ValueError(f"{err=}") from err

for task in tasks:
result = await task
yield result
94 changes: 94 additions & 0 deletions bambooflow/tests/test_datapipes_callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""
Tests for callable datapipes.
"""
import asyncio
import re
import time
from collections.abc import Awaitable

import pytest

from bambooflow.datapipes import AsyncIterableWrapper, Mapper


# %%
@pytest.fixture(scope="function", name="times_two")
def fixture_times_two():
async def times_two(x) -> int:
await asyncio.sleep(0.2)
print(f"Multiplying {x} by 2")
result = x * 2
return result

return times_two


@pytest.fixture(scope="function", name="times_three")
def fixture_times_three():
async def times_three(x) -> int:
await asyncio.sleep(0.3)
print(f"Multiplying {x} by 3")
result = x * 3
return result

return times_three


@pytest.fixture(scope="function", name="error_four")
def fixture_error_four():
async def error_four(x):
await asyncio.sleep(0.1)
if x == 4:
raise ValueError(f"Some problem with {x}")

return error_four


# %%
async def test_mapper_concurrency(times_two, times_three):
"""
Ensure that MapperAsyncIterDataPipe works to process tasks concurrently,
such that three tasks taking 3*(0.2+0.3)=1.5 seconds in serial can be
completed in just (0.2+0.3)=0.5 seconds instead.
"""
dp = AsyncIterableWrapper(iterable=[0, 1, 2])
dp_map2 = Mapper(datapipe=dp, fn=times_two)
dp_map3 = Mapper(datapipe=dp_map2, fn=times_three)

i = 0
tic = time.perf_counter()
async for num in dp_map3:
# print("Number:", num)
assert num == i * 2 * 3
toc = time.perf_counter()
i += 1
# print(f"Ran in {toc - tic:0.4f} seconds")
print(f"Total: {toc - tic:0.4f} seconds")

assert toc - tic < 0.55 # Total time should be about 0.5 seconds
assert num == 12 # 2*2*3=12


async def test_mapper_exception_handling(error_four):
"""
Ensure that MapperAsyncIterDataPipe can capture exceptions when one of the
tasks raises an error.
"""
dp = AsyncIterableWrapper(iterable=[3, 4, 5])
dp_map = Mapper(datapipe=dp, fn=error_four)

it = aiter(dp_map)
number = anext(it)
# Checek that an ExceptionGroup is already raised on first access
with pytest.raises(
ValueError,
match=re.escape(
"err=ExceptionGroup('unhandled errors in a TaskGroup', [ValueError('Some problem with 4')])"
),
):
await number

# Subsequent access to iterator should raise StopAsyncIteration
number = anext(it)
with pytest.raises(StopAsyncIteration):
await number
3 changes: 3 additions & 0 deletions docs/_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ sphinx:
config:
myst_all_links_external: true
html_show_copyright: false
html_theme_options:
# https://sphinx-book-theme.readthedocs.io/en/stable/customize/sidebar-secondary.html
show_toc_level: 3
extlinks:
py-term:
- 'https://docs.python.org/3/glossary.html#term-%s'
Expand Down
10 changes: 10 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,13 @@
.. autoclass:: bambooflow.datapipes.aiter.AsyncIterableWrapperAsyncIterDataPipe
:show-inheritance:
```

### Mapping DataPipes

Datapipes which apply a custom asynchronous function to elements in a DataPipe.

```{eval-rst}
.. autoclass:: bambooflow.datapipes.Mapper
.. autoclass:: bambooflow.datapipes.callable.MapperAsyncIterDataPipe
:show-inheritance:
```