diff --git a/docs/conf.py b/docs/conf.py index 16afa5272..21aedf31c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -68,6 +68,7 @@ ("py:class", "'id'"), ("py:class", "typing_extensions.Literal"), ("py:func", "int"), + ("py:class", "asyncio.locks.Event"), ] # Both the class’ and the __init__ method’s docstring are concatenated and diff --git a/examples/hdf_queue_reporting.py b/examples/hdf_queue_reporting.py index 09d3fe821..8eaab6d97 100644 --- a/examples/hdf_queue_reporting.py +++ b/examples/hdf_queue_reporting.py @@ -2,7 +2,7 @@ import sys import time -from pandablocks.asyncio import AsyncioClient +from pandablocks.asyncio import AsyncioClient, FlushMode from pandablocks.commands import Arm, Put from pandablocks.hdf import FrameProcessor, HDFWriter, create_pipeline, stop_pipeline from pandablocks.responses import EndData, EndReason, FrameData, ReadyData @@ -28,7 +28,9 @@ async def hdf_queue_reporting(): client.send(Put("SEQ1.PRESCALE", 0.5)), ) progress = 0 - async for data in client.data(scaled=False, flush_period=1): + async for data in client.data( + scaled=False, flush_period=1, flush_mode=FlushMode.PERIODIC + ): # Always pass the data down the pipeline pipeline[0].queue.put_nowait(data) if isinstance(data, ReadyData): diff --git a/src/pandablocks/asyncio.py b/src/pandablocks/asyncio.py index bb0db8e1f..cc53abdb4 100644 --- a/src/pandablocks/asyncio.py +++ b/src/pandablocks/asyncio.py @@ -2,6 +2,7 @@ import logging from asyncio.streams import StreamReader, StreamWriter from contextlib import suppress +from enum import Enum from typing import AsyncGenerator, Dict, Iterable, Optional from .commands import Command, T @@ -9,7 +10,7 @@ from .responses import Data # Define the public API of this module -__all__ = ["AsyncioClient"] +__all__ = ["AsyncioClient", "FlushMode"] class _StreamHelper: @@ -50,6 +51,21 @@ async def close(self): await writer.wait_closed() +class FlushMode(Enum): + """ + The mode which `AsyncioClient.data()` uses when flushing data frames. + """ + + #: Flush all data frames immediately. + IMMEDIATE = 0 + + #: Flush data frames periodically. + PERIODIC = 1 + + #: Flush data frames when the user sets an `asyncio.Event`. + MANUAL = 2 + + class AsyncioClient: """Asyncio implementation of a PandABlocks client. For example:: @@ -72,7 +88,7 @@ def __init__(self, host: str): async def connect(self): """Connect to the control port, and be ready to handle commands""" - await self._ctrl_stream.connect(self._host, 8888), + await self._ctrl_stream.connect(self._host, 8888) self._ctrl_task = asyncio.create_task( self._ctrl_read_forever(self._ctrl_stream.reader) @@ -132,36 +148,84 @@ async def data( scaled: bool = True, flush_period: Optional[float] = None, frame_timeout: Optional[float] = None, + flush_event: Optional[asyncio.Event] = None, + flush_mode: FlushMode = FlushMode.IMMEDIATE, ) -> AsyncGenerator[Data, None]: - """Connect to data port and yield data frames + """Connect to data port and yield data frames. Args: scaled: Whether to scale and average data frames, reduces throughput - flush_period: How often to flush partial data frames, None is on every - chunk of data from the server + flush_period: How often to flush partial data frames when ``flush_mode`` + is ``PERIODIC``. frame_timeout: If no data is received for this amount of time, raise - `asyncio.TimeoutError` + `asyncio.TimeoutError`. + flush_event: An `asyncio.Event` to manually flush. When set while + ``flush_mode`` is ``MANUAL`` a flush will be performed and the event + will be unset. + flush_mode: Which mode of the `FlushMode` values. """ + # Check input + if flush_mode == FlushMode.MANUAL: + if not flush_event: + raise ValueError( + f"flush_event cannot be {flush_event} if flush_mode is 'MANUAL'" + ) + if flush_period: + raise ValueError( + f"Unused flush_period={flush_period} " + "inputted during `MANUAL` flush mode" + ) + elif flush_mode == FlushMode.PERIODIC: + if not flush_period: + raise ValueError( + f"flush_period cannot be {flush_period} " + "if flush_mode is 'PERIODIC', use `MANUAL` " + "flush_mode to make flushing exclusively " + "manual" + ) + elif flush_mode == FlushMode.IMMEDIATE: + if flush_event: + raise ValueError( + f"Unused flush_event={flush_event} " + "inputted during `IMMEDIATE` flush mode" + ) + if flush_period: + raise ValueError( + f"Unused flush_period={flush_period} " + "inputted during `IMMEDIATE` flush mode" + ) + else: + raise ValueError( + "flush_mode must be one of 'IMMEDIATE', 'PERIODIC', 'MANUAL'" + ) + stream = _StreamHelper() connection = DataConnection() queue: asyncio.Queue[Iterable[Data]] = asyncio.Queue() + flush_every_frame = flush_mode == FlushMode.IMMEDIATE + flush_event = flush_event or asyncio.Event() def raise_timeouterror(): raise asyncio.TimeoutError(f"No data received for {frame_timeout}s") yield - async def periodic_flush(): - if flush_period is not None: - while True: - # Every flush_period seconds flush and queue data - await asyncio.sleep(flush_period) - queue.put_nowait(connection.flush()) + async def flush_loop(): + if flush_every_frame: + # If flush mode is `IMMEDIATE` flushing will be performed + # whenever possible + return + while True: + try: + await asyncio.wait_for(flush_event.wait(), flush_period) + except asyncio.TimeoutError: + pass + else: + flush_event.clear() + queue.put_nowait(connection.flush()) async def read_from_stream(): reader = stream.reader - # Should we flush every FrameData? - flush_every_frame = flush_period is None while True: try: recv = await asyncio.wait_for(reader.read(4096), frame_timeout) @@ -169,11 +233,15 @@ async def read_from_stream(): queue.put_nowait(raise_timeouterror()) break else: - queue.put_nowait(connection.receive_bytes(recv, flush_every_frame)) + queue.put_nowait( + connection.receive_bytes( + recv, flush_every_frame=flush_every_frame + ) + ) await stream.connect(self._host, 8889) await stream.write_and_drain(connection.connect(scaled)) - fut = asyncio.gather(periodic_flush(), read_from_stream()) + fut = asyncio.gather(read_from_stream(), flush_loop()) try: while True: for data in await queue.get(): diff --git a/src/pandablocks/hdf.py b/src/pandablocks/hdf.py index d220909bd..5b504c013 100644 --- a/src/pandablocks/hdf.py +++ b/src/pandablocks/hdf.py @@ -8,7 +8,7 @@ from pandablocks.commands import Arm -from .asyncio import AsyncioClient +from .asyncio import AsyncioClient, FlushMode from .connections import SAMPLES_FIELD from .responses import EndData, EndReason, FieldCapture, FrameData, ReadyData, StartData @@ -223,7 +223,9 @@ async def write_hdf_files( end_data = None pipeline = create_default_pipeline(file_names) try: - async for data in client.data(scaled=False, flush_period=flush_period): + async for data in client.data( + scaled=False, flush_period=flush_period, flush_mode=FlushMode.PERIODIC + ): pipeline[0].queue.put_nowait(data) if type(data) == EndData: end_data = data diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 9e3902a8d..cd1a994a0 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -2,8 +2,9 @@ import pytest -from pandablocks.asyncio import AsyncioClient +from pandablocks.asyncio import AsyncioClient, FlushMode from pandablocks.commands import CommandException, Get, Put +from pandablocks.responses import EndData, FrameData, ReadyData, StartData from .conftest import DummyServer @@ -30,34 +31,165 @@ async def test_asyncio_bad_put_raises(dummy_server_async): assert dummy_server_async.received == ["PCAP.thing=1"] -@pytest.mark.asyncio @pytest.mark.parametrize("disarmed", [True, False]) -@pytest.mark.parametrize("flush_period", [0.1, None]) -async def test_asyncio_data( - dummy_server_async, fast_dump, fast_dump_expected, disarmed, flush_period +async def test_asyncio_data_periodic_flushing( + dummy_server_async, fast_dump, fast_dump_expected, disarmed ): if not disarmed: # simulate getting the data without the END marker as if arm was not pressed fast_dump = (x.split(b"END")[0] for x in fast_dump) fast_dump_expected = list(fast_dump_expected)[:-1] + dummy_server_async.data = fast_dump events = [] async with AsyncioClient("localhost") as client: - async for data in client.data(frame_timeout=1, flush_period=flush_period): + async for data in client.data( + frame_timeout=1, flush_period=0.1, flush_mode=FlushMode.PERIODIC + ): events.append(data) + if len(events) == len(fast_dump_expected): break + assert fast_dump_expected == events +@pytest.mark.parametrize( + "flush_mode, flush_period", + [ + (FlushMode.PERIODIC, float("inf")), # Testing manual flush in PERIODIC mode + (FlushMode.MANUAL, None), + ], +) +async def test_asyncio_data_manual_flushing( + dummy_server_async, fast_dump, fast_dump_expected, flush_mode, flush_period +): + dummy_server_async.data = fast_dump + + # Button push event + flush_event = asyncio.Event() + + async def wait_and_press_button(data_generator): + await asyncio.sleep(0.2) + flush_event.set() + return await data_generator.__anext__() + + async with AsyncioClient("localhost") as client: + data_generator = client.data( + frame_timeout=5, + flush_event=flush_event, + flush_mode=flush_mode, + flush_period=flush_period, + ) + + assert isinstance(await data_generator.__anext__(), ReadyData) + assert isinstance(await data_generator.__anext__(), StartData) + assert isinstance(await wait_and_press_button(data_generator), FrameData) + assert not flush_event.is_set() + assert isinstance(await wait_and_press_button(data_generator), FrameData) + assert not flush_event.is_set() + assert isinstance(await wait_and_press_button(data_generator), FrameData) + assert not flush_event.is_set() + assert isinstance(await wait_and_press_button(data_generator), FrameData) + assert isinstance(await data_generator.__anext__(), EndData) + await data_generator.aclose() + + +async def test_asyncio_data_flush_every_frame( + dummy_server_async, fast_dump, fast_dump_expected +): + dummy_server_async.data = fast_dump + + events = [] + async with AsyncioClient("localhost") as client: + async for data in client.data(frame_timeout=5, flush_mode=FlushMode.IMMEDIATE): + events.append(data) + + if len(events) == len(fast_dump_expected): + break + + assert fast_dump_expected == events + + +async def test_asyncio_data_timeout_on_no_manual_press(dummy_server_async, fast_dump): + flush_event = asyncio.Event() + async with AsyncioClient("localhost") as client: + with pytest.raises(asyncio.TimeoutError, match="No data received for 3s"): + async for data in client.data( + frame_timeout=3, flush_mode=FlushMode.MANUAL, flush_event=flush_event + ): + pass + + async def test_asyncio_data_timeout(dummy_server_async, fast_dump): dummy_server_async.data = fast_dump async with AsyncioClient("localhost") as client: with pytest.raises(asyncio.TimeoutError, match="No data received for 0.1s"): - async for data in client.data(frame_timeout=0.1): + async for data in client.data( + frame_timeout=0.1, flush_mode=FlushMode.IMMEDIATE + ): "This goes forever, when it runs out of data we will get our timeout" +async def test_asyncio_data_nonexistent_flushmode(dummy_server_async, fast_dump): + dummy_server_async.data = fast_dump + async with AsyncioClient("localhost") as client: + with pytest.raises(ValueError, match="flush_mode must be one of"): + async for data in client.data(frame_timeout=0.1, flush_mode=None): + pass + + +@pytest.mark.parametrize( + "kwargs, error", + [ + ( + {"flush_mode": FlushMode.PERIODIC, "flush_period": 0}, + "flush_period cannot be 0", + ), + ( + { + "flush_mode": FlushMode.PERIODIC, + "flush_period": 0, + "flush_event": asyncio.Event(), + }, + " use `MANUAL` flush_mode to make flushing exclusively manual", + ), + ( + {"flush_mode": FlushMode.PERIODIC, "flush_period": None}, + "flush_period cannot be None", + ), + ( + {"flush_mode": FlushMode.MANUAL, "flush_event": None}, + "flush_event cannot be None", + ), + ( + { + "flush_mode": FlushMode.MANUAL, + "flush_event": asyncio.Event(), + "flush_period": 1, + }, + "Unused flush_period=1", + ), + ( + {"flush_mode": FlushMode.IMMEDIATE, "flush_period": 1}, + "Unused flush_period=1", + ), + ( + {"flush_mode": FlushMode.IMMEDIATE, "flush_event": asyncio.Event()}, + "Unused flush_event=