From c533d2949ffbd1c267c396ddbd161939cd347b19 Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Thu, 18 Jul 2024 15:21:08 +0100 Subject: [PATCH] 87 make hdf datasets always be float64 unless they are pcap bits (#93) * hdf datasets now always use a float64 unless they are PCAP.BITS or PCAP.SAMPLES * also dropped python 3.8 --- .github/pages/make_switcher.py | 12 ++--- .github/workflows/ci.yml | 2 +- pyproject.toml | 4 +- src/pandablocks/_control.py | 12 ++--- src/pandablocks/_exchange.py | 11 +++-- src/pandablocks/asyncio.py | 5 +- src/pandablocks/blocking.py | 5 +- src/pandablocks/cli.py | 6 +-- src/pandablocks/commands.py | 85 ++++++++++++++++------------------ src/pandablocks/connections.py | 35 ++++++++++---- src/pandablocks/hdf.py | 38 +++++++-------- src/pandablocks/responses.py | 77 +++++++++++++++++++----------- src/pandablocks/utils.py | 19 ++++---- tests/conftest.py | 14 +++--- tests/test_hdf.py | 78 +++++++++++++++++++++++++++++++ tests/test_pandablocks.py | 3 +- tests/test_utils.py | 26 +++++------ 17 files changed, 275 insertions(+), 157 deletions(-) diff --git a/.github/pages/make_switcher.py b/.github/pages/make_switcher.py index e2c8e6f62..2b81e7696 100755 --- a/.github/pages/make_switcher.py +++ b/.github/pages/make_switcher.py @@ -3,28 +3,28 @@ from argparse import ArgumentParser from pathlib import Path from subprocess import CalledProcessError, check_output -from typing import List, Optional +from typing import Optional -def report_output(stdout: bytes, label: str) -> List[str]: +def report_output(stdout: bytes, label: str) -> list[str]: ret = stdout.decode().strip().split("\n") print(f"{label}: {ret}") return ret -def get_branch_contents(ref: str) -> List[str]: +def get_branch_contents(ref: str) -> list[str]: """Get the list of directories in a branch.""" stdout = check_output(["git", "ls-tree", "-d", "--name-only", ref]) return report_output(stdout, "Branch contents") -def get_sorted_tags_list() -> List[str]: +def get_sorted_tags_list() -> list[str]: """Get a list of sorted tags in descending order from the repository.""" stdout = check_output(["git", "tag", "-l", "--sort=-v:refname"]) return report_output(stdout, "Tags list") -def get_versions(ref: str, add: Optional[str]) -> List[str]: +def get_versions(ref: str, add: Optional[str]) -> list[str]: """Generate the file containing the list of all GitHub Pages builds.""" # Get the directories (i.e. builds) from the GitHub Pages branch try: @@ -41,7 +41,7 @@ def get_versions(ref: str, add: Optional[str]) -> List[str]: tags = get_sorted_tags_list() # Make the sorted versions list from main branches and tags - versions: List[str] = [] + versions: list[str] = [] for version in ["master", "main"] + tags: if version in builds: versions.append(version) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 888eb4d80..27769757f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: strategy: matrix: runs-on: ["ubuntu-latest"] # can add windows-latest, macos-latest - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11"] include: # Include one that runs in the dev environment - runs-on: "ubuntu-latest" diff --git a/pyproject.toml b/pyproject.toml index 73539a517..69bc0d963 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,8 +7,6 @@ name = "pandablocks" classifiers = [ "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -18,7 +16,7 @@ dependencies = ["typing-extensions;python_version<'3.8'", "numpy", "click"] dynamic = ["version"] license.file = "LICENSE" readme = "README.md" -requires-python = ">=3.7" +requires-python = ">=3.9" [project.optional-dependencies] h5py = ["h5py", "matplotlib"] diff --git a/src/pandablocks/_control.py b/src/pandablocks/_control.py index a13dcd825..a020bc759 100644 --- a/src/pandablocks/_control.py +++ b/src/pandablocks/_control.py @@ -1,5 +1,5 @@ from string import digits -from typing import Dict, List, Optional +from typing import Optional from .blocking import BlockingClient from .commands import FieldInfo, GetBlockInfo, GetFieldInfo, Raw, is_multiline_command @@ -25,7 +25,7 @@ STATIC_STAR_COMMANDS.append(f"*CHANGES{suff}=") # Reset reported changes -def _get_user_input(prompt) -> List[str]: +def _get_user_input(prompt) -> list[str]: lines = [input(prompt)] if is_multiline_command(lines[0]): while lines[-1]: @@ -39,7 +39,7 @@ def text_matches(t1, t2): class BlockCompleter: def __init__(self, client: BlockingClient): - self.matches: List[str] = [] + self.matches: list[str] = [] self._client = client self._blocks = self._client.send( GetBlockInfo(skip_description=True), timeout=TIMEOUT @@ -47,14 +47,14 @@ def __init__(self, client: BlockingClient): self._fields = self._get_fields(list(self._blocks)) # TODO: Extend use of _fields now we have enum labels available? - def _get_fields(self, blocks: List[str]) -> Dict[str, Dict[str, FieldInfo]]: + def _get_fields(self, blocks: list[str]) -> dict[str, dict[str, FieldInfo]]: fields = self._client.send( [GetFieldInfo(block, extended_metadata=False) for block in blocks], timeout=TIMEOUT, ) return dict(zip(blocks, fields)) - def _with_suffixes(self, block: str, numbers: bool) -> List[str]: + def _with_suffixes(self, block: str, numbers: bool) -> list[str]: block_info = self._blocks[block] num = block_info.number if numbers and num > 1: @@ -62,7 +62,7 @@ def _with_suffixes(self, block: str, numbers: bool) -> List[str]: else: return [block] - def _block_field_matches(self, text: str, prefix="") -> List[str]: + def _block_field_matches(self, text: str, prefix="") -> list[str]: matches = [] text = text[len(prefix) :] split = text.split(".", maxsplit=1) diff --git a/src/pandablocks/_exchange.py b/src/pandablocks/_exchange.py index f912df21f..f41f24470 100644 --- a/src/pandablocks/_exchange.py +++ b/src/pandablocks/_exchange.py @@ -1,4 +1,5 @@ -from typing import Generator, List, TypeVar, Union +from collections.abc import Generator +from typing import TypeVar, Union T = TypeVar("T") @@ -7,12 +8,12 @@ class Exchange: """A helper class representing the lines to send to PandA and the lines received""" - def __init__(self, to_send: Union[str, List[str]]): + def __init__(self, to_send: Union[str, list[str]]): if isinstance(to_send, str): self.to_send = [to_send] else: self.to_send = to_send - self.received: List[str] = [] + self.received: list[str] = [] self.is_multiline = False @property @@ -22,12 +23,12 @@ def line(self) -> str: return self.received[0] @property - def multiline(self) -> List[str]: + def multiline(self) -> list[str]: """Return the multiline received lines, processed to remove markup""" assert self.is_multiline # Remove the ! and . markup return [line[1:] for line in self.received[:-1]] -Exchanges = Union[Exchange, List[Exchange]] +Exchanges = Union[Exchange, list[Exchange]] ExchangeGenerator = Generator[Exchanges, None, T] diff --git a/src/pandablocks/asyncio.py b/src/pandablocks/asyncio.py index ce793d4ea..9794b142e 100644 --- a/src/pandablocks/asyncio.py +++ b/src/pandablocks/asyncio.py @@ -1,8 +1,9 @@ import asyncio import logging from asyncio.streams import StreamReader, StreamWriter +from collections.abc import AsyncGenerator, Iterable from contextlib import suppress -from typing import AsyncGenerator, Dict, Iterable, Optional +from typing import Optional from .commands import Command, T from .connections import ControlConnection, DataConnection @@ -67,7 +68,7 @@ def __init__(self, host: str): self._host = host self._ctrl_connection = ControlConnection() self._ctrl_task: Optional[asyncio.Task] = None - self._ctrl_queues: Dict[int, asyncio.Queue] = {} + self._ctrl_queues: dict[int, asyncio.Queue] = {} self._ctrl_stream = _StreamHelper() async def connect(self): diff --git a/src/pandablocks/blocking.py b/src/pandablocks/blocking.py index f2c818890..91991e9e9 100644 --- a/src/pandablocks/blocking.py +++ b/src/pandablocks/blocking.py @@ -1,5 +1,6 @@ import socket -from typing import Iterable, Iterator, List, Optional, Union, overload +from collections.abc import Iterable, Iterator +from typing import Optional, Union, overload from .commands import Command, T from .connections import ControlConnection, DataConnection @@ -69,7 +70,7 @@ def send(self, commands: Command[T], timeout: Optional[int] = None) -> T: ... @overload def send( self, commands: Iterable[Command], timeout: Optional[int] = None - ) -> List: ... + ) -> list: ... def send( self, diff --git a/src/pandablocks/cli.py b/src/pandablocks/cli.py index 7d5894fe0..0f3ade86a 100644 --- a/src/pandablocks/cli.py +++ b/src/pandablocks/cli.py @@ -2,7 +2,7 @@ import io import logging import pathlib -from typing import Awaitable, List +from collections.abc import Awaitable import click from click.exceptions import ClickException @@ -69,7 +69,7 @@ def save(host: str, outfile: io.TextIOWrapper): Save the current blocks configuration of HOST to OUTFILE """ - async def _save(host: str) -> List[str]: + async def _save(host: str) -> list[str]: async with AsyncioClient(host) as client: return await client.send(GetState()) @@ -93,7 +93,7 @@ def load(host: str, infile: io.TextIOWrapper, tutorial: bool): else: state = infile.read().splitlines() - async def _load(host: str, state: List[str]): + async def _load(host: str, state: list[str]): async with AsyncioClient(host) as client: await client.send(SetState(state)) diff --git a/src/pandablocks/commands.py b/src/pandablocks/commands.py index e387d2bce..0d1561f77 100644 --- a/src/pandablocks/commands.py +++ b/src/pandablocks/commands.py @@ -1,16 +1,13 @@ import logging import re +from collections.abc import Generator from dataclasses import dataclass, field from enum import Enum from typing import ( Any, Callable, - Dict, - Generator, Generic, - List, Optional, - Tuple, TypeVar, Union, overload, @@ -92,31 +89,31 @@ class CommandException(Exception): # zip() because typing does not support variadic type variables. See # typeshed PR #1550 for discussion. @overload -def _execute_commands(c1: Command[T]) -> ExchangeGenerator[Tuple[T]]: ... +def _execute_commands(c1: Command[T]) -> ExchangeGenerator[tuple[T]]: ... @overload def _execute_commands( c1: Command[T], c2: Command[T2] -) -> ExchangeGenerator[Tuple[T, T2]]: ... +) -> ExchangeGenerator[tuple[T, T2]]: ... @overload def _execute_commands( c1: Command[T], c2: Command[T2], c3: Command[T3] -) -> ExchangeGenerator[Tuple[T, T2, T3]]: ... +) -> ExchangeGenerator[tuple[T, T2, T3]]: ... @overload def _execute_commands( c1: Command[T], c2: Command[T2], c3: Command[T3], c4: Command[T4] -) -> ExchangeGenerator[Tuple[T, T2, T3, T4]]: ... +) -> ExchangeGenerator[tuple[T, T2, T3, T4]]: ... @overload def _execute_commands( *commands: Command[Any], -) -> ExchangeGenerator[Tuple[Any, ...]]: ... +) -> ExchangeGenerator[tuple[Any, ...]]: ... def _execute_commands(*commands): @@ -131,13 +128,13 @@ def _execute_commands(*commands): def _zip_with_return( - generators: List[ExchangeGenerator[Any]], -) -> ExchangeGenerator[Tuple[Any, ...]]: + generators: list[ExchangeGenerator[Any]], +) -> ExchangeGenerator[tuple[Any, ...]]: # Sentinel to show what generators are not yet exhausted pending = object() returns = [pending] * len(generators) while True: - yields: List[Exchange] = [] + yields: list[Exchange] = [] for i, gen in enumerate(generators): # If we haven't exhausted the generator if returns[i] is pending: @@ -164,7 +161,7 @@ def _zip_with_return( @dataclass -class Raw(Command[List[str]]): +class Raw(Command[list[str]]): """Send a raw command Args: @@ -177,16 +174,16 @@ class Raw(Command[List[str]]): Raw(["SEQ1.TABLE?"]) -> ["!1", "!1", "!0", "!0", "."]) """ - inp: List[str] + inp: list[str] - def execute(self) -> ExchangeGenerator[List[str]]: + def execute(self) -> ExchangeGenerator[list[str]]: ex = Exchange(self.inp) yield ex return ex.received @dataclass -class Get(Command[Union[str, List[str]]]): +class Get(Command[Union[str, list[str]]]): """Get the value of a field or star command. If the form of the expected return is known, consider using `GetLine` @@ -204,7 +201,7 @@ class Get(Command[Union[str, List[str]]]): field: str - def execute(self) -> ExchangeGenerator[Union[str, List[str]]]: + def execute(self) -> ExchangeGenerator[Union[str, list[str]]]: ex = Exchange(f"{self.field}?") yield ex if ex.is_multiline: @@ -242,7 +239,7 @@ def execute(self) -> ExchangeGenerator[str]: @dataclass -class GetMultiline(Command[List[str]]): +class GetMultiline(Command[list[str]]): """Get the value of a field or star command, when the result is expected to be a multiline response. @@ -257,7 +254,7 @@ class GetMultiline(Command[List[str]]): field: str - def execute(self) -> ExchangeGenerator[List[str]]: + def execute(self) -> ExchangeGenerator[list[str]]: ex = Exchange(f"{self.field}?") yield ex return ex.multiline @@ -279,7 +276,7 @@ class Put(Command[None]): """ field: str - value: Union[str, List[str]] = "" + value: Union[str, list[str]] = "" def execute(self) -> ExchangeGenerator[None]: if isinstance(self.value, list): @@ -305,7 +302,7 @@ class Append(Command[None]): """ field: str - value: List[str] + value: list[str] def execute(self) -> ExchangeGenerator[None]: # Multiline table with blank line to terminate @@ -333,7 +330,7 @@ def execute(self) -> ExchangeGenerator[None]: @dataclass -class GetBlockInfo(Command[Dict[str, BlockInfo]]): +class GetBlockInfo(Command[dict[str, BlockInfo]]): """Get the name, number, and description of each block type in a dictionary, alphabetically ordered @@ -353,7 +350,7 @@ class GetBlockInfo(Command[Dict[str, BlockInfo]]): skip_description: bool = False - def execute(self) -> ExchangeGenerator[Dict[str, BlockInfo]]: + def execute(self) -> ExchangeGenerator[dict[str, BlockInfo]]: ex = Exchange("*BLOCKS?") yield ex @@ -380,14 +377,14 @@ def execute(self) -> ExchangeGenerator[Dict[str, BlockInfo]]: # The type of the generators used for creating the Get commands for each field # and setting the returned data into the FieldInfo structure _FieldGeneratorType = Generator[ - Union[Exchange, List[Exchange]], + Union[Exchange, list[Exchange]], None, - Tuple[str, FieldInfo], + tuple[str, FieldInfo], ] @dataclass -class GetFieldInfo(Command[Dict[str, FieldInfo]]): +class GetFieldInfo(Command[dict[str, FieldInfo]]): """Get the fields of a block, returning a `FieldInfo` (or appropriate subclass) for each one, ordered to match the definition order in the PandA @@ -412,8 +409,8 @@ class GetFieldInfo(Command[Dict[str, FieldInfo]]): block: str extended_metadata: bool = True - _commands_map: Dict[ - Tuple[str, Optional[str]], + _commands_map: dict[ + tuple[str, Optional[str]], Callable[ [str, str, Optional[str]], _FieldGeneratorType, @@ -585,10 +582,10 @@ def _table( # Keep track of highest bit index max_bit_offset: int = 0 - desc_gets: List[GetLine] = [] - enum_field_gets: List[GetMultiline] = [] - enum_field_names: List[str] = [] - fields_dict: Dict[str, TableFieldDetails] = {} + desc_gets: list[GetLine] = [] + enum_field_gets: list[GetMultiline] = [] + enum_field_names: list[str] = [] + fields_dict: dict[str, TableFieldDetails] = {} for field_details in fields: # Fields are of the form : bit_range, name, subtype = field_details.split() @@ -612,7 +609,7 @@ def _table( # Calculate the number of 32 bit words that comprises one table row row_words = max_bit_offset // 32 + 1 - # The first len(enum_field_gets) items are enum labels, type List[str] + # The first len(enum_field_gets) items are enum labels, type list[str] # The second part of the list are descriptions, type str labels_and_descriptions = yield from _execute_commands( *enum_field_gets, *desc_gets @@ -683,11 +680,11 @@ def _no_attributes( return field_name, FieldInfo(field_type, field_subtype, desc) - def execute(self) -> ExchangeGenerator[Dict[str, FieldInfo]]: + def execute(self) -> ExchangeGenerator[dict[str, FieldInfo]]: ex = Exchange(f"{self.block}.*?") yield ex - unsorted: Dict[int, Tuple[str, FieldInfo]] = {} - field_generators: List[ExchangeGenerator] = [] + unsorted: dict[int, tuple[str, FieldInfo]] = {} + field_generators: list[ExchangeGenerator] = [] for line in ex.multiline: field_name, index, type_subtype = line.split(maxsplit=2) @@ -733,7 +730,7 @@ def execute(self) -> ExchangeGenerator[Dict[str, FieldInfo]]: # Asked to not perform the requests for extra metadata. return fields - field_name_info: Tuple[Tuple[str, FieldInfo], ...] + field_name_info: tuple[tuple[str, FieldInfo], ...] field_name_info = yield from _zip_with_return(field_generators) fields.update(field_name_info) @@ -749,7 +746,7 @@ class GetPcapBitsLabels(Command): GetPcapBitsLabels() -> {"BITS0" : ["TTLIN1.VAL", "TTLIN2.VAL", ...], ...} """ - def execute(self) -> ExchangeGenerator[Dict[str, List[str]]]: + def execute(self) -> ExchangeGenerator[dict[str, list[str]]]: ex = Exchange("PCAP.*?") yield ex bits_fields = [] @@ -827,7 +824,7 @@ def execute(self) -> ExchangeGenerator[Changes]: ex = Exchange(f"*CHANGES{self.group.value}?") yield ex changes = Changes({}, [], [], {}) - multivalue_get_commands: List[Tuple[str, GetMultiline]] = [] + multivalue_get_commands: list[tuple[str, GetMultiline]] = [] for line in ex.multiline: if line[-1] == "<": if self.get_multiline: @@ -857,7 +854,7 @@ def execute(self) -> ExchangeGenerator[Changes]: @dataclass -class GetState(Command[List[str]]): +class GetState(Command[list[str]]): """Get the state of all the fields in a PandA that should be saved as a list of raw lines that could be sent with `SetState`. @@ -875,7 +872,7 @@ class GetState(Command[List[str]]): ] """ - def execute(self) -> ExchangeGenerator[List[str]]: + def execute(self) -> ExchangeGenerator[list[str]]: # TODO: explain in detail how this works # See: references/how-it-works attr, config, table, metadata = yield from _execute_commands( @@ -920,11 +917,11 @@ class SetState(Command[None]): ]) """ - state: List[str] + state: list[str] def execute(self) -> ExchangeGenerator[None]: - commands: List[Raw] = [] - command_lines: List[str] = [] + commands: list[Raw] = [] + command_lines: list[str] = [] for line in self.state: command_lines.append(line) first_line = len(command_lines) == 1 diff --git a/src/pandablocks/connections.py b/src/pandablocks/connections.py index 059f4d0e3..75eeb17af 100644 --- a/src/pandablocks/connections.py +++ b/src/pandablocks/connections.py @@ -2,8 +2,9 @@ import sys import xml.etree.ElementTree as ET from collections import deque +from collections.abc import Iterator from dataclasses import dataclass -from typing import Any, Callable, Deque, Iterator, List, Optional, Tuple +from typing import Any, Callable, Optional import numpy as np @@ -153,11 +154,11 @@ class ControlConnection: def __init__(self) -> None: self._buf = Buffer() - self._lines: List[str] = [] - self._contexts: Deque[_ExchangeContext] = deque() - self._responses: Deque[Tuple[Command, Any]] = deque() + self._lines: list[str] = [] + self._contexts: deque[_ExchangeContext] = deque() + self._responses: deque[tuple[Command, Any]] = deque() - def _update_contexts(self, lines: List[str], is_multiline=False) -> bytes: + def _update_contexts(self, lines: list[str], is_multiline=False) -> bytes: to_send = b"" if len(self._contexts) == 0: raise NoContextAvailable() @@ -234,7 +235,7 @@ def receive_bytes(self, received: bytes) -> bytes: to_send += self._update_contexts([line]) return to_send - def responses(self) -> Iterator[Tuple[Command, Any]]: + def responses(self) -> Iterator[tuple[Command, Any]]: """Get the (command, response) tuples generated as part of the last receive_bytes""" while self._responses: @@ -303,15 +304,22 @@ def _handle_header_body(self): if line == b"": fields = [] root = ET.fromstring(self._header) + for field in root.find("fields"): fields.append( FieldCapture( name=str(field.get("name")), type=np.dtype(field.get("type")), capture=str(field.get("capture")), - scale=float(field.get("scale", 1)), - offset=float(field.get("offset", 0)), - units=str(field.get("units", "")), + scale=float(scale) + if (scale := field.get("scale")) is not None + else None, + offset=float(offset) + if (offset := field.get("offset")) is not None + else None, + units=str(units) + if (units := field.get("units")) is not None + else None, ) ) data = root.find("data") @@ -323,7 +331,14 @@ def _handle_header_body(self): name, capture = SAMPLES_FIELD.rsplit(".", maxsplit=1) fields.insert( 0, - FieldCapture(name, np.dtype("uint32"), capture), + FieldCapture( + name=name, + type=np.dtype("uint32"), + capture=capture, + scale=None, + offset=None, + units=None, + ), ) self._frame_dtype = np.dtype( [(f"{f.name}.{f.capture}", f.type) for f in fields] diff --git a/src/pandablocks/hdf.py b/src/pandablocks/hdf.py index ca7abc7f2..6e25edfbd 100644 --- a/src/pandablocks/hdf.py +++ b/src/pandablocks/hdf.py @@ -1,7 +1,8 @@ import logging import queue import threading -from typing import Any, Callable, Dict, Iterator, List, Optional, Type +from collections.abc import Iterator +from typing import Any, Callable, Optional import h5py import numpy as np @@ -44,7 +45,7 @@ class Pipeline(threading.Thread): #: Subclasses should create this dictionary with handlers for each data #: type, returning transformed data that should be passed downstream - what_to_do: Dict[Type, Callable] + what_to_do: dict[type, Callable] downstream: Optional["Pipeline"] = None def __init__(self): @@ -95,12 +96,12 @@ class HDFWriter(Pipeline): def __init__( self, file_names: Iterator[str], - capture_record_hdf_names: Dict[str, Dict[str, str]], + capture_record_hdf_names: dict[str, dict[str, str]], ): super().__init__() self.file_names = file_names self.hdf_file: Optional[h5py.File] = None - self.datasets: List[h5py.Dataset] = [] + self.datasets: list[h5py.Dataset] = [] self.capture_record_hdf_names = capture_record_hdf_names self.what_to_do = { StartData: self.open_file, @@ -111,16 +112,13 @@ def __init__( def create_dataset(self, field: FieldCapture, raw: bool): # Data written in a big stack, growing in that dimension assert self.hdf_file, "File not open yet" - if raw and (field.capture == "Mean" or field.scale != 1 or field.offset != 0): - # Processor outputs a float - dtype = np.dtype("float64") - else: - # No processor, datatype passed through - dtype = field.type + dataset_name = self.capture_record_hdf_names.get(field.name, {}).get( field.capture, f"{field.name}.{field.capture}" ) + dtype = field.raw_mode_dataset_dtype if raw else field.type + return self.hdf_file.create_dataset( f"/{dataset_name}", dtype=dtype, @@ -154,7 +152,7 @@ def open_file(self, data: StartData): f"stored in {len(self.datasets)} datasets" ) - def write_frame(self, data: List[np.ndarray]): + def write_frame(self, data: list[np.ndarray]): for dataset, column in zip(self.datasets, data): # Append to the end, flush when done written = dataset.shape[0] @@ -181,7 +179,7 @@ class FrameProcessor(Pipeline): def __init__(self) -> None: super().__init__() - self.processors: List[Callable] = [] + self.processors: list[Callable] = [] self.what_to_do = { StartData: self.create_processors, FrameData: self.scale_data, @@ -201,7 +199,7 @@ def mean_callable(data): return (data[column_name] * field.scale / gate_duration) + field.offset return mean_callable - elif raw and (field.scale != 1 or field.offset != 0): + elif raw and field.has_scale_or_offset: return lambda data: data[column_name] * field.scale + field.offset else: return lambda data: data[column_name] @@ -211,15 +209,15 @@ def create_processors(self, data: StartData) -> StartData: self.processors = [self.create_processor(field, raw) for field in data.fields] return data - def scale_data(self, data: FrameData) -> List[np.ndarray]: + def scale_data(self, data: FrameData) -> list[np.ndarray]: return [process(data.data) for process in self.processors] def create_default_pipeline( file_names: Iterator[str], - capture_record_hdf_names: Dict[str, Dict[str, str]], + capture_record_hdf_names: dict[str, dict[str, str]], *additional_downstream_pipelines: Pipeline, -) -> List[Pipeline]: +) -> list[Pipeline]: """Create the default processing pipeline consisting of one `FrameProcessor` and one `HDFWriter`. See `create_pipeline` for more details. @@ -240,10 +238,10 @@ def create_default_pipeline( ) -def create_pipeline(*elements: Pipeline) -> List[Pipeline]: +def create_pipeline(*elements: Pipeline) -> list[Pipeline]: """Create a pipeline of elements, wiring them and starting them before returning them""" - pipeline: List[Pipeline] = [] + pipeline: list[Pipeline] = [] for element in elements: if pipeline: pipeline[-1].downstream = element @@ -252,7 +250,7 @@ def create_pipeline(*elements: Pipeline) -> List[Pipeline]: return pipeline -def stop_pipeline(pipeline: List[Pipeline]): +def stop_pipeline(pipeline: list[Pipeline]): """Stop and join each element of the pipeline""" for element in pipeline: # Note that we stop and join each element in turn. @@ -289,7 +287,7 @@ async def write_hdf_files( try: async for data in client.data(scaled=False, flush_period=flush_period): pipeline[0].queue.put_nowait(data) - if type(data) == EndData: + if isinstance(data, EndData): end_data = data counter += 1 if counter == num: diff --git a/src/pandablocks/responses.py b/src/pandablocks/responses.py index a68bb2984..01d2e558d 100644 --- a/src/pandablocks/responses.py +++ b/src/pandablocks/responses.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum -from typing import Dict, List, Optional, Tuple +from typing import Optional import numpy as np @@ -89,7 +89,7 @@ class ScalarFieldInfo(FieldInfo): class TimeFieldInfo(FieldInfo): """Extended `FieldInfo` for fields with type "time""" - units_labels: List[str] + units_labels: list[str] min_val: float @@ -98,7 +98,7 @@ class SubtypeTimeFieldInfo(FieldInfo): """Extended `FieldInfo` for fields with type "param","read", or "write" and subtype "time""" - units_labels: List[str] + units_labels: list[str] @dataclass @@ -106,7 +106,7 @@ class EnumFieldInfo(FieldInfo): """Extended `FieldInfo` for fields with type "param","read", or "write" and subtype "enum""" - labels: List[str] + labels: list[str] @dataclass @@ -122,14 +122,14 @@ class BitMuxFieldInfo(FieldInfo): """Extended `FieldInfo` for fields with type "bit_mux""" max_delay: int - labels: List[str] + labels: list[str] @dataclass class PosMuxFieldInfo(FieldInfo): """Extended `FieldInfo` for fields with type "pos_mux""" - labels: List[str] + labels: list[str] @dataclass @@ -140,7 +140,7 @@ class TableFieldDetails: bit_low: int bit_high: int description: Optional[str] = None - labels: Optional[List[str]] = None + labels: Optional[list[str]] = None @dataclass @@ -148,7 +148,7 @@ class TableFieldInfo(FieldInfo): """Extended `FieldInfo` for fields with type "table""" max_length: int - fields: Dict[str, TableFieldDetails] + fields: dict[str, TableFieldDetails] row_words: int @@ -156,7 +156,7 @@ class TableFieldInfo(FieldInfo): class PosOutFieldInfo(FieldInfo): """Extended `FieldInfo` for fields with type "pos_out""" - capture_labels: List[str] + capture_labels: list[str] @dataclass @@ -164,14 +164,14 @@ class ExtOutFieldInfo(FieldInfo): """Extended `FieldInfo` for fields with type "ext_out" and subtypes "timestamp" or "samples""" - capture_labels: List[str] + capture_labels: list[str] @dataclass class ExtOutBitsFieldInfo(ExtOutFieldInfo): """Extended `ExtOutFieldInfo` for fields with type "ext_out" and subtype "bits""" - bits: List[str] + bits: list[str] @dataclass @@ -179,13 +179,13 @@ class Changes: """The changes returned from a ``*CHANGES`` command""" #: Map field -> value for single-line values that were returned - values: Dict[str, str] + values: dict[str, str] #: The fields that were present but without value - no_value: List[str] + no_value: list[str] #: The fields that were in error - in_error: List[str] + in_error: list[str] #: Map field -> value for multi-line values that were returned - multiline_values: Dict[str, List[str]] + multiline_values: dict[str, list[str]] # Data @@ -223,21 +223,46 @@ class EndReason(Enum): class FieldCapture: """Information about a field that is being captured + If scale, offset, and units are all `None`, then the field is a + ``PCAP.BITS``. + Attributes: name: Name of captured field type: Numpy data type of the field as transmitted capture: Value of CAPTURE field used to enable this field - scale: Scaling factor, default 1.0 - offset: Offset, default 0.0 - units: Units string, default "" + scale: Scaling factor + offset: Offset + units: Units string """ name: str type: np.dtype capture: str - scale: float = 1.0 - offset: float = 0.0 - units: str = "" + scale: Optional[float] = field(default=None) + offset: Optional[float] = field(default=None) + units: Optional[str] = field(default=None) + + def __post_init__(self): + sou = (self.scale, self.offset, self.units) + if sou != (None, None, None) and None in sou: + raise ValueError( + f"If any of `scale={self.scale}`, `offset={self.offset}`" + f", or `units={self.units}` is set, all must be set." + ) + + @property + def raw_mode_dataset_dtype(self) -> np.dtype: + """We use double for all dtypes that have scale and offset.""" + if self.scale is not None and self.offset is not None: + return np.dtype("float64") + return self.type + + @property + def has_scale_or_offset(self) -> bool: + """Return True if this field is a PCAP.BITS or PCAP.SAMPLES field""" + return (self.scale is not None and self.offset is not None) and ( + self.scale != 1 or self.offset != 0 + ) class Data: @@ -261,7 +286,7 @@ class StartData(Data): sample_bytes: Number of bytes in one sample """ - fields: List[FieldCapture] + fields: list[FieldCapture] missed: int process: str format: str @@ -289,8 +314,8 @@ class FrameData(Data): ... (2, 12)], ... dtype=[('COUNTER1.OUT.Value', '>> fdata = FrameData(data) - >>> fdata.data[0] # Row view - (0., 10.) + >>> (fdata.data[0]['COUNTER1.OUT.Value'], fdata.data[0]['COUNTER2.OUT.Value']) + (np.float64(0.0), np.float64(10.0)) >>> fdata.column_names # Column names ('COUNTER1.OUT.Value', 'COUNTER2.OUT.Value') >>> fdata.data['COUNTER1.OUT.Value'] # Column view @@ -300,7 +325,7 @@ class FrameData(Data): data: np.ndarray @property - def column_names(self) -> Tuple[str, ...]: + def column_names(self) -> tuple[str, ...]: """Return all the column names""" names = self.data.dtype.names assert names, f"No column names for {self.data.dtype}" diff --git a/src/pandablocks/utils.py b/src/pandablocks/utils.py index 233995065..e34e69df6 100644 --- a/src/pandablocks/utils.py +++ b/src/pandablocks/utils.py @@ -1,4 +1,5 @@ -from typing import Dict, Iterable, List, Union, cast +from collections.abc import Iterable +from typing import Union, cast import numpy as np import numpy.typing as npt @@ -10,7 +11,7 @@ npt.NDArray[np.uint16], npt.NDArray[np.int32], npt.NDArray[np.uint32], - List[str], + list[str], ] @@ -18,7 +19,7 @@ def words_to_table( words: Iterable[str], table_field_info: TableFieldInfo, convert_enum_indices: bool = False, -) -> Dict[str, UnpackedArray]: +) -> dict[str, UnpackedArray]: """Unpacks the given `packed` data based on the fields provided. Returns the unpacked data in {column_name: column_data} column-indexed format @@ -41,7 +42,7 @@ def words_to_table( data = data.reshape(len(data) // row_words, row_words) packed = data.T - unpacked: Dict[str, UnpackedArray] = {} + unpacked: dict[str, UnpackedArray] = {} for field_name, field_info in table_field_info.fields.items(): offset = field_info.bit_low @@ -62,7 +63,9 @@ def words_to_table( if field_info.subtype == "int": # First convert from 2's complement to offset, then add in offset. - temp = (value ^ (1 << (bit_length - 1))) + (-1 << (bit_length - 1)) + temp = (value.astype(np.int64) ^ (1 << (bit_length - 1))) + ( + -1 << (bit_length - 1) + ) packing_value = temp.astype(np.int32) elif field_info.subtype == "enum" and convert_enum_indices: assert field_info.labels, f"Enum field {field_name} has no labels" @@ -82,8 +85,8 @@ def words_to_table( def table_to_words( - table: Dict[str, UnpackedArray], table_field_info: TableFieldInfo -) -> List[str]: + table: dict[str, UnpackedArray], table_field_info: TableFieldInfo +) -> list[str]: """Convert records based on the field definitions into the format PandA expects for table writes. @@ -93,7 +96,7 @@ def table_to_words( table_field_info: The info for tables, containing the dict `fields` for information on each field, and the number of words per row. Returns: - List[str]: The list of data ready to be sent to PandA + list[str]: The list of data ready to be sent to PandA """ row_words = table_field_info.row_words diff --git a/tests/conftest.py b/tests/conftest.py index 1ed7d365b..b9e894f7e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,9 +2,9 @@ import os import threading from collections import deque +from collections.abc import Iterable, Iterator from io import BufferedReader from pathlib import Path -from typing import Deque, Dict, Iterable, Iterator, List import numpy as np import pytest @@ -88,9 +88,9 @@ def overrun_dump(): name="PCAP.BITS2", type=np.dtype("uint32"), capture="Value", - scale=1, - offset=0, - units="", + scale=None, + offset=None, + units=None, ), FieldCapture( name="COUNTER1.OUT", @@ -281,13 +281,13 @@ class DummyServer: # when it sees an expected string. When the expected message is seen the # response will be left-appended to the send buffer so it is sent next. # Items are removed from the Dict when they are sent. - expected_message_responses: Dict[str, str] = {} + expected_message_responses: dict[str, str] = {} def __init__(self) -> None: # This will be added to whenever control port gets a message - self.received: List[str] = [] + self.received: list[str] = [] # Add to this to give the control port something to send back - self.send: Deque[str] = deque() + self.send: deque[str] = deque() # Add to this to give the data port something to send self.data: Iterable[bytes] = [] diff --git a/tests/test_hdf.py b/tests/test_hdf.py index 290491aca..298aa6c77 100644 --- a/tests/test_hdf.py +++ b/tests/test_hdf.py @@ -61,6 +61,84 @@ def __init__(self): stop_pipeline(pipeline) +def test_field_capture_pcap_bits(): + pcap_bits_frame_data = FieldCapture( + name="PCAP.BITS", + type=np.dtype("uint32"), + capture="Value", + scale=None, + offset=None, + units=None, + ) + + assert not pcap_bits_frame_data.has_scale_or_offset + assert pcap_bits_frame_data.raw_mode_dataset_dtype is np.dtype("uint32") + + frame_data_without_scale_offset = FieldCapture( + name="frame_data_without_scale_offset", + type=np.dtype("uint32"), + capture="Value", + scale=1.0, + offset=0.0, + units="", + ) + + assert not frame_data_without_scale_offset.has_scale_or_offset + assert frame_data_without_scale_offset.raw_mode_dataset_dtype is np.dtype("float64") + + with pytest.raises( + ValueError, + match=( + "If any of `scale=None`, `offset=0.0`, or " + "`units=` is set, all must be set" + ), + ): + _ = FieldCapture( + name="malformed_frame_data", + type=np.dtype("uint32"), + capture="Value", + scale=None, + offset=0.0, + units="", + ) + + frame_data_with_offset = FieldCapture( + name="frame_data_with_offset", + type=np.dtype("float64"), + capture="Value", + scale=1.0, + offset=1.0, + units="", + ) + frame_data_with_scale = FieldCapture( + name="frame_data_with_scale", + type=np.dtype("float64"), + capture="Value", + scale=1.1, + offset=0.0, + units="", + ) + + assert frame_data_with_offset.has_scale_or_offset + assert frame_data_with_offset.raw_mode_dataset_dtype is np.dtype("float64") + assert frame_data_with_scale.has_scale_or_offset + assert frame_data_with_scale.raw_mode_dataset_dtype is np.dtype("float64") + + frame_data_with_scale_and_offset = FieldCapture( + name="frame_data_with_scale_and_offset", + type=np.dtype("float64"), + capture="Value", + scale=1.1, + offset=0.0, + units="", + ) + + assert frame_data_with_scale_and_offset.has_scale_or_offset + assert frame_data_with_scale_and_offset.raw_mode_dataset_dtype is np.dtype( + "float64" + ) + + @pytest.mark.parametrize( "capture_record_hdf_names,expected_names", [ diff --git a/tests/test_pandablocks.py b/tests/test_pandablocks.py index ab588bfa0..c06565dd0 100644 --- a/tests/test_pandablocks.py +++ b/tests/test_pandablocks.py @@ -1,4 +1,5 @@ -from typing import Iterator, OrderedDict +from collections import OrderedDict +from collections.abc import Iterator import pytest diff --git a/tests/test_utils.py b/tests/test_utils.py index e9424b34c..29b6a710f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -from typing import Dict, List, OrderedDict +from collections import OrderedDict import numpy as np import pytest @@ -8,7 +8,7 @@ @pytest.fixture -def table_fields() -> Dict[str, TableFieldDetails]: +def table_fields() -> dict[str, TableFieldDetails]: """Table field definitions, taken from a SEQ.TABLE instance. Associated with table_data and table_field_info fixtures""" return { @@ -211,7 +211,7 @@ def table_1_np_arrays_int_enums() -> OrderedDict[str, UnpackedArray]: @pytest.fixture -def table_data_1() -> List[str]: +def table_data_1() -> list[str]: return [ "2457862149", "4294967291", @@ -229,8 +229,8 @@ def table_data_1() -> List[str]: @pytest.fixture -def table_2_np_arrays() -> Dict[str, UnpackedArray]: - table: Dict[str, UnpackedArray] = { +def table_2_np_arrays() -> dict[str, UnpackedArray]: + table: dict[str, UnpackedArray] = { "REPEATS": np.array([1, 0], dtype=np.uint32), "TRIGGER": ["Immediate", "Immediate"], "POSITION": np.array([-20, 2**31 - 1], dtype=np.int32), @@ -247,7 +247,7 @@ def table_2_np_arrays() -> Dict[str, UnpackedArray]: @pytest.fixture -def table_data_2() -> List[str]: +def table_data_2() -> list[str]: return [ "67108865", "4294967276", @@ -288,8 +288,8 @@ def test_table_to_words_and_words_to_table( table_field_info: TableFieldInfo, request, ): - table: Dict[str, UnpackedArray] = request.getfixturevalue(table_fixture_name) - table_data: List[str] = request.getfixturevalue(table_data_fixture_name) + table: dict[str, UnpackedArray] = request.getfixturevalue(table_fixture_name) + table_data: list[str] = request.getfixturevalue(table_data_fixture_name) output_data = table_to_words(table, table_field_info) assert output_data == table_data @@ -312,7 +312,7 @@ def test_table_to_words_and_words_to_table( def test_table_packing_unpack( table_1_np_arrays: OrderedDict[str, np.ndarray], table_field_info: TableFieldInfo, - table_data_1: List[str], + table_data_1: list[str], ): assert table_field_info.row_words output_table = words_to_table( @@ -328,7 +328,7 @@ def test_table_packing_unpack( def test_table_packing_unpack_no_convert_enum( table_1_np_arrays_int_enums: OrderedDict[str, UnpackedArray], table_field_info: TableFieldInfo, - table_data_1: List[str], + table_data_1: list[str], ): assert table_field_info.row_words output_table = words_to_table(table_data_1, table_field_info) @@ -340,9 +340,9 @@ def test_table_packing_unpack_no_convert_enum( def test_table_packing_pack( - table_1_np_arrays: Dict[str, UnpackedArray], + table_1_np_arrays: dict[str, UnpackedArray], table_field_info: TableFieldInfo, - table_data_1: List[str], + table_data_1: list[str], ): assert table_field_info.row_words unpacked = table_to_words(table_1_np_arrays, table_field_info) @@ -352,7 +352,7 @@ def test_table_packing_pack( def test_table_packing_give_default_values( - table_1_np_arrays: Dict[str, UnpackedArray], + table_1_np_arrays: dict[str, UnpackedArray], table_field_info: TableFieldInfo, ): # We should have a complete table at the point of unpacking