diff --git a/src/lgdo/lh5/_serializers/read/utils.py b/src/lgdo/lh5/_serializers/read/utils.py
index a83be2be..5474cca3 100644
--- a/src/lgdo/lh5/_serializers/read/utils.py
+++ b/src/lgdo/lh5/_serializers/read/utils.py
@@ -1,10 +1,16 @@
 from __future__ import annotations
 
+import logging
+
 import h5py
 import numpy as np
 
+from .... import types
+from ... import datatype
 from ...exceptions import LH5DecodeError
 
+log = logging.getLogger(__name__)
+
 
 def check_obj_buf_attrs(attrs, new_attrs, fname, oname):
     if set(attrs.keys()) != set(new_attrs.keys()):
@@ -33,3 +39,82 @@ def read_attrs(h5o, fname, oname):
             attrs[name] = val.item()
         h5a.close()
     return attrs
+
+
+def read_n_rows(h5o, fname, oname):
+    """Return the number of rows of an LH5 object without reading its data.
+
+    Parameters
+    ----------
+    h5o
+        low-level h5py identifier of the object to inspect.
+    fname
+        name of the file the object lives in (used in error messages).
+    oname
+        name of the object (used in error and warning messages).
+
+    Returns
+    -------
+    n_rows
+        size of the first dimension of the object, or ``None`` for scalars
+        and structs. For tables, the row count of the first field that
+        has one.
+
+    Raises
+    ------
+    LH5DecodeError
+        if the ``datatype`` attribute is missing or the LGDO type is not
+        supported.
+    """
+    if not h5py.h5a.exists(h5o, b"datatype"):
+        msg = "missing 'datatype' attribute"
+        raise LH5DecodeError(msg, fname, oname)
+
+    h5a = h5py.h5a.open(h5o, b"datatype")
+    type_attr = np.empty((), h5a.dtype)
+    h5a.read(type_attr)
+    type_attr = type_attr.item().decode()
+    lgdotype = datatype.datatype(type_attr)
+
+    # scalars are dim-0 datasets
+    if lgdotype is types.Scalar:
+        return None
+
+    # structs don't have rows
+    if lgdotype is types.Struct:
+        return None
+
+    # tables should have elements with all the same length
+    if lgdotype is types.Table:
+        # read out each of the fields
+        rows_read = None
+        for field in datatype.get_struct_fields(type_attr):
+            n_rows_read = read_n_rows(h5py.h5o.open(h5o, field.encode()), fname, field)
+            # test identity: 0 rows is a valid count for the first field
+            if rows_read is None:
+                rows_read = n_rows_read
+            elif rows_read != n_rows_read:
+                log.warning(
+                    f"'{field}' field in table '{oname}' has {n_rows_read} rows, "
+                    f"{rows_read} was expected"
+                )
+
+        return rows_read
+
+    # length of vector of vectors is the length of its cumulative_length
+    if lgdotype is types.VectorOfVectors:
+        return read_n_rows(
+            h5py.h5o.open(h5o, b"cumulative_length"), fname, "cumulative_length"
+        )
+
+    # length of vector of encoded vectors is the length of its encoded_data
+    if lgdotype in (types.VectorOfEncodedVectors, types.ArrayOfEncodedEqualSizedArrays):
+        return read_n_rows(h5py.h5o.open(h5o, b"encoded_data"), fname, "encoded_data")
+
+    # return array length (without reading the array!)
+    if issubclass(lgdotype, types.Array):
+        # compute the number of rows to read
+        return h5o.get_space().shape[0]
+
+    msg = f"don't know how to read rows of LGDO {lgdotype.__name__}"
+    raise LH5DecodeError(msg, fname, oname)
diff --git a/src/lgdo/lh5/utils.py b/src/lgdo/lh5/utils.py
index ceb3b7ac..a19a0ed8 100644
--- a/src/lgdo/lh5/utils.py
+++ b/src/lgdo/lh5/utils.py
@@ -12,7 +12,7 @@ import h5py
 
 from .. import types
-from . import _serializers, datatype
+from . import _serializers
 from .exceptions import LH5DecodeError
 
 log = logging.getLogger(__name__)
 
@@ -47,54 +47,13 @@ def read_n_rows(name: str, h5f: str | h5py.File) -> int | None:
         h5f = h5py.File(h5f, "r")
 
     try:
-        attrs = h5f[name].attrs
+        h5o = h5f[name].id
     except KeyError as e:
         msg = "not found"
         raise LH5DecodeError(msg, h5f, name) from e
-    except AttributeError as e:
-        msg = "missing 'datatype' attribute"
-        raise LH5DecodeError(msg, h5f, name) from e
-
-    lgdotype = datatype.datatype(attrs["datatype"])
-
-    # scalars are dim-0 datasets
-    if lgdotype is types.Scalar:
-        return None
-
-    # structs don't have rows
-    if lgdotype is types.Struct:
-        return None
-
-    # tables should have elements with all the same length
-    if lgdotype is types.Table:
-        # read out each of the fields
-        rows_read = None
-        for field in datatype.get_struct_fields(attrs["datatype"]):
-            n_rows_read = read_n_rows(name + "/" + field, h5f)
-            if not rows_read:
-                rows_read = n_rows_read
-            elif rows_read != n_rows_read:
-                log.warning(
-                    f"'{field}' field in table '{name}' has {rows_read} rows, "
-                    f"{n_rows_read} was expected"
-                )
-        return rows_read
-
-    # length of vector of vectors is the length of its cumulative_length
-    if lgdotype is types.VectorOfVectors:
-        return read_n_rows(f"{name}/cumulative_length", h5f)
-
-    # length of vector of encoded vectors is the length of its decoded_size
-    if lgdotype in (types.VectorOfEncodedVectors, types.ArrayOfEncodedEqualSizedArrays):
-        return read_n_rows(f"{name}/encoded_data", h5f)
-
-    # return array length (without reading the array!)
-    if issubclass(lgdotype, types.Array):
-        # compute the number of rows to read
-        return h5f[name].shape[0]
 
-    msg = f"don't know how to read rows of LGDO {lgdotype.__name__}"
-    raise LH5DecodeError(msg, h5f, name)
+    # File.name is the HDF5 path of the root group ("/"); use the on-disk path
+    return _serializers.read.utils.read_n_rows(h5o, h5f.filename, name)
 
 
 def get_h5_group(