Skip to content

Commit

Permalink
Merge pull request #64 from evalott100/add_manual_flush
Browse files Browse the repository at this point in the history
Add manual flush
  • Loading branch information
evalott100 authored Nov 21, 2023
2 parents 7117d68 + 3693df5 commit c6ab247
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 27 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
("py:class", "'id'"),
("py:class", "typing_extensions.Literal"),
("py:func", "int"),
("py:class", "asyncio.locks.Event"),
]

# Both the class’ and the __init__ method’s docstring are concatenated and
Expand Down
6 changes: 4 additions & 2 deletions examples/hdf_queue_reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
import time

from pandablocks.asyncio import AsyncioClient
from pandablocks.asyncio import AsyncioClient, FlushMode
from pandablocks.commands import Arm, Put
from pandablocks.hdf import FrameProcessor, HDFWriter, create_pipeline, stop_pipeline
from pandablocks.responses import EndData, EndReason, FrameData, ReadyData
Expand All @@ -28,7 +28,9 @@ async def hdf_queue_reporting():
client.send(Put("SEQ1.PRESCALE", 0.5)),
)
progress = 0
async for data in client.data(scaled=False, flush_period=1):
async for data in client.data(
scaled=False, flush_period=1, flush_mode=FlushMode.PERIODIC
):
# Always pass the data down the pipeline
pipeline[0].queue.put_nowait(data)
if isinstance(data, ReadyData):
Expand Down
100 changes: 84 additions & 16 deletions src/pandablocks/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import logging
from asyncio.streams import StreamReader, StreamWriter
from contextlib import suppress
from enum import Enum
from typing import AsyncGenerator, Dict, Iterable, Optional

from .commands import Command, T
from .connections import ControlConnection, DataConnection
from .responses import Data

# Define the public API of this module
__all__ = ["AsyncioClient"]
__all__ = ["AsyncioClient", "FlushMode"]


class _StreamHelper:
Expand Down Expand Up @@ -50,6 +51,21 @@ async def close(self):
await writer.wait_closed()


class FlushMode(Enum):
"""
The mode which `AsyncioClient.data()` uses when flushing data frames.
"""

#: Flush all data frames immediately.
IMMEDIATE = 0

#: Flush data frames periodically.
PERIODIC = 1

#: Flush data frames when the user sets an `asyncio.Event`.
MANUAL = 2


class AsyncioClient:
"""Asyncio implementation of a PandABlocks client.
For example::
Expand All @@ -72,7 +88,7 @@ def __init__(self, host: str):

async def connect(self):
"""Connect to the control port, and be ready to handle commands"""
await self._ctrl_stream.connect(self._host, 8888),
await self._ctrl_stream.connect(self._host, 8888)

self._ctrl_task = asyncio.create_task(
self._ctrl_read_forever(self._ctrl_stream.reader)
Expand Down Expand Up @@ -132,48 +148,100 @@ async def data(
scaled: bool = True,
flush_period: Optional[float] = None,
frame_timeout: Optional[float] = None,
flush_event: Optional[asyncio.Event] = None,
flush_mode: FlushMode = FlushMode.IMMEDIATE,
) -> AsyncGenerator[Data, None]:
"""Connect to data port and yield data frames
"""Connect to data port and yield data frames.
Args:
scaled: Whether to scale and average data frames, reduces throughput
flush_period: How often to flush partial data frames, None is on every
chunk of data from the server
flush_period: How often to flush partial data frames when ``flush_mode``
is ``PERIODIC``.
frame_timeout: If no data is received for this amount of time, raise
`asyncio.TimeoutError`
`asyncio.TimeoutError`.
flush_event: An `asyncio.Event` to manually flush. When set while
``flush_mode`` is ``MANUAL`` a flush will be performed and the event
will be unset.
flush_mode: Which mode of the `FlushMode` values.
"""

# Check input
if flush_mode == FlushMode.MANUAL:
if not flush_event:
raise ValueError(
f"flush_event cannot be {flush_event} if flush_mode is 'MANUAL'"
)
if flush_period:
raise ValueError(
f"Unused flush_period={flush_period} "
"inputted during `MANUAL` flush mode"
)
elif flush_mode == FlushMode.PERIODIC:
if not flush_period:
raise ValueError(
f"flush_period cannot be {flush_period} "
"if flush_mode is 'PERIODIC', use `MANUAL` "
"flush_mode to make flushing exclusively "
"manual"
)
elif flush_mode == FlushMode.IMMEDIATE:
if flush_event:
raise ValueError(
f"Unused flush_event={flush_event} "
"inputted during `IMMEDIATE` flush mode"
)
if flush_period:
raise ValueError(
f"Unused flush_period={flush_period} "
"inputted during `IMMEDIATE` flush mode"
)
else:
raise ValueError(
"flush_mode must be one of 'IMMEDIATE', 'PERIODIC', 'MANUAL'"
)

stream = _StreamHelper()
connection = DataConnection()
queue: asyncio.Queue[Iterable[Data]] = asyncio.Queue()
flush_every_frame = flush_mode == FlushMode.IMMEDIATE
flush_event = flush_event or asyncio.Event()

def raise_timeouterror():
raise asyncio.TimeoutError(f"No data received for {frame_timeout}s")
yield

async def periodic_flush():
if flush_period is not None:
while True:
# Every flush_period seconds flush and queue data
await asyncio.sleep(flush_period)
queue.put_nowait(connection.flush())
async def flush_loop():
if flush_every_frame:
# If flush mode is `IMMEDIATE` flushing will be performed
# whenever possible
return
while True:
try:
await asyncio.wait_for(flush_event.wait(), flush_period)
except asyncio.TimeoutError:
pass
else:
flush_event.clear()
queue.put_nowait(connection.flush())

async def read_from_stream():
reader = stream.reader
# Should we flush every FrameData?
flush_every_frame = flush_period is None
while True:
try:
recv = await asyncio.wait_for(reader.read(4096), frame_timeout)
except asyncio.TimeoutError:
queue.put_nowait(raise_timeouterror())
break
else:
queue.put_nowait(connection.receive_bytes(recv, flush_every_frame))
queue.put_nowait(
connection.receive_bytes(
recv, flush_every_frame=flush_every_frame
)
)

await stream.connect(self._host, 8889)
await stream.write_and_drain(connection.connect(scaled))
fut = asyncio.gather(periodic_flush(), read_from_stream())
fut = asyncio.gather(read_from_stream(), flush_loop())
try:
while True:
for data in await queue.get():
Expand Down
6 changes: 4 additions & 2 deletions src/pandablocks/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from pandablocks.commands import Arm

from .asyncio import AsyncioClient
from .asyncio import AsyncioClient, FlushMode
from .connections import SAMPLES_FIELD
from .responses import EndData, EndReason, FieldCapture, FrameData, ReadyData, StartData

Expand Down Expand Up @@ -223,7 +223,9 @@ async def write_hdf_files(
end_data = None
pipeline = create_default_pipeline(file_names)
try:
async for data in client.data(scaled=False, flush_period=flush_period):
async for data in client.data(
scaled=False, flush_period=flush_period, flush_mode=FlushMode.PERIODIC
):
pipeline[0].queue.put_nowait(data)
if type(data) == EndData:
end_data = data
Expand Down
146 changes: 139 additions & 7 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import pytest

from pandablocks.asyncio import AsyncioClient
from pandablocks.asyncio import AsyncioClient, FlushMode
from pandablocks.commands import CommandException, Get, Put
from pandablocks.responses import EndData, FrameData, ReadyData, StartData

from .conftest import DummyServer

Expand All @@ -30,34 +31,165 @@ async def test_asyncio_bad_put_raises(dummy_server_async):
assert dummy_server_async.received == ["PCAP.thing=1"]


@pytest.mark.asyncio
@pytest.mark.parametrize("disarmed", [True, False])
@pytest.mark.parametrize("flush_period", [0.1, None])
async def test_asyncio_data(
dummy_server_async, fast_dump, fast_dump_expected, disarmed, flush_period
async def test_asyncio_data_periodic_flushing(
dummy_server_async, fast_dump, fast_dump_expected, disarmed
):
if not disarmed:
# simulate getting the data without the END marker as if arm was not pressed
fast_dump = (x.split(b"END")[0] for x in fast_dump)
fast_dump_expected = list(fast_dump_expected)[:-1]

dummy_server_async.data = fast_dump
events = []
async with AsyncioClient("localhost") as client:
async for data in client.data(frame_timeout=1, flush_period=flush_period):
async for data in client.data(
frame_timeout=1, flush_period=0.1, flush_mode=FlushMode.PERIODIC
):
events.append(data)

if len(events) == len(fast_dump_expected):
break

assert fast_dump_expected == events


@pytest.mark.parametrize(
"flush_mode, flush_period",
[
(FlushMode.PERIODIC, float("inf")), # Testing manual flush in PERIODIC mode
(FlushMode.MANUAL, None),
],
)
async def test_asyncio_data_manual_flushing(
dummy_server_async, fast_dump, fast_dump_expected, flush_mode, flush_period
):
dummy_server_async.data = fast_dump

# Button push event
flush_event = asyncio.Event()

async def wait_and_press_button(data_generator):
await asyncio.sleep(0.2)
flush_event.set()
return await data_generator.__anext__()

async with AsyncioClient("localhost") as client:
data_generator = client.data(
frame_timeout=5,
flush_event=flush_event,
flush_mode=flush_mode,
flush_period=flush_period,
)

assert isinstance(await data_generator.__anext__(), ReadyData)
assert isinstance(await data_generator.__anext__(), StartData)
assert isinstance(await wait_and_press_button(data_generator), FrameData)
assert not flush_event.is_set()
assert isinstance(await wait_and_press_button(data_generator), FrameData)
assert not flush_event.is_set()
assert isinstance(await wait_and_press_button(data_generator), FrameData)
assert not flush_event.is_set()
assert isinstance(await wait_and_press_button(data_generator), FrameData)
assert isinstance(await data_generator.__anext__(), EndData)
await data_generator.aclose()


async def test_asyncio_data_flush_every_frame(
dummy_server_async, fast_dump, fast_dump_expected
):
dummy_server_async.data = fast_dump

events = []
async with AsyncioClient("localhost") as client:
async for data in client.data(frame_timeout=5, flush_mode=FlushMode.IMMEDIATE):
events.append(data)

if len(events) == len(fast_dump_expected):
break

assert fast_dump_expected == events


async def test_asyncio_data_timeout_on_no_manual_press(dummy_server_async, fast_dump):
flush_event = asyncio.Event()
async with AsyncioClient("localhost") as client:
with pytest.raises(asyncio.TimeoutError, match="No data received for 3s"):
async for data in client.data(
frame_timeout=3, flush_mode=FlushMode.MANUAL, flush_event=flush_event
):
pass


async def test_asyncio_data_timeout(dummy_server_async, fast_dump):
dummy_server_async.data = fast_dump
async with AsyncioClient("localhost") as client:
with pytest.raises(asyncio.TimeoutError, match="No data received for 0.1s"):
async for data in client.data(frame_timeout=0.1):
async for data in client.data(
frame_timeout=0.1, flush_mode=FlushMode.IMMEDIATE
):
"This goes forever, when it runs out of data we will get our timeout"


async def test_asyncio_data_nonexistent_flushmode(dummy_server_async, fast_dump):
dummy_server_async.data = fast_dump
async with AsyncioClient("localhost") as client:
with pytest.raises(ValueError, match="flush_mode must be one of"):
async for data in client.data(frame_timeout=0.1, flush_mode=None):
pass


@pytest.mark.parametrize(
"kwargs, error",
[
(
{"flush_mode": FlushMode.PERIODIC, "flush_period": 0},
"flush_period cannot be 0",
),
(
{
"flush_mode": FlushMode.PERIODIC,
"flush_period": 0,
"flush_event": asyncio.Event(),
},
" use `MANUAL` flush_mode to make flushing exclusively manual",
),
(
{"flush_mode": FlushMode.PERIODIC, "flush_period": None},
"flush_period cannot be None",
),
(
{"flush_mode": FlushMode.MANUAL, "flush_event": None},
"flush_event cannot be None",
),
(
{
"flush_mode": FlushMode.MANUAL,
"flush_event": asyncio.Event(),
"flush_period": 1,
},
"Unused flush_period=1",
),
(
{"flush_mode": FlushMode.IMMEDIATE, "flush_period": 1},
"Unused flush_period=1",
),
(
{"flush_mode": FlushMode.IMMEDIATE, "flush_event": asyncio.Event()},
"Unused flush_event=<asyncio.locks.Event object",
),
],
)
async def test_asyncio_data_incorrect_arguments(
dummy_server_async, fast_dump, kwargs, error
):
dummy_server_async.data = fast_dump
async with AsyncioClient("localhost") as client:
with pytest.raises(ValueError, match=error):
async for data in client.data(**kwargs):
pass


@pytest.mark.asyncio
async def test_asyncio_connects(dummy_server_async: DummyServer):
async with AsyncioClient("localhost") as client:
Expand Down

0 comments on commit c6ab247

Please sign in to comment.