Skip to content

Commit

Permalink
Use low level interface for read_n_rows
Browse files Browse the repository at this point in the history
  • Loading branch information
iguinn committed Oct 24, 2024
1 parent ed7a879 commit 21bd790
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 46 deletions.
63 changes: 62 additions & 1 deletion src/lgdo/lh5/_serializers/read/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from __future__ import annotations

import logging

import h5py
import numpy as np

from .... import types
from ... import datatype
from ...exceptions import LH5DecodeError

log = logging.getLogger(__name__)


def check_obj_buf_attrs(attrs, new_attrs, fname, oname):
if set(attrs.keys()) != set(new_attrs.keys()):
Expand All @@ -23,7 +29,7 @@ def read_attrs(h5o, fname, oname):
h5a = h5py.h5a.open(h5o, index=i_attr)
name = h5a.get_name().decode()
if h5a.shape != ():
msg = f"attribute {name} is not a string or scalar"
msg = f"attribute {oname} is not a string or scalar"
raise LH5DecodeError(msg, fname, oname)
val = np.empty((), h5a.dtype)
h5a.read(val)
Expand All @@ -33,3 +39,58 @@ def read_attrs(h5o, fname, oname):
attrs[name] = val.item()
h5a.close()
return attrs


def _read_child_n_rows(h5o, child, fname, oname):
    """Open sub-object ``child`` of ``h5o``, return its row count, then close it."""
    h5c = h5py.h5o.open(h5o, child.encode())
    try:
        # keep the full path in the object name so that errors and warnings
        # raised deeper in the hierarchy still identify the offending object
        return read_n_rows(h5c, fname, f"{oname}/{child}")
    finally:
        # release the low-level identifier as soon as it is no longer needed
        h5c.close()


def read_n_rows(h5o, fname, oname):
    """Read the number of rows of an LH5 object without reading its data.

    Parameters
    ----------
    h5o
        low-level HDF5 object identifier of the LH5 object; it must be
        usable with :func:`h5py.h5a.open` and :func:`h5py.h5o.open`.
    fname
        name of the file holding the object (used for error reporting only).
    oname
        name of the object (used for error reporting and log messages).

    Returns
    -------
    int or None
        the number of rows, or ``None`` for scalars and structs, which have
        no rows. For tables, the row count of the first field that reports
        one is returned; a warning is logged for every other field whose
        row count differs.

    Raises
    ------
    LH5DecodeError
        if the object has no ``datatype`` attribute, or if its LGDO type is
        not one this function knows how to measure.
    """
    if not h5py.h5a.exists(h5o, b"datatype"):
        msg = "missing 'datatype' attribute"
        raise LH5DecodeError(msg, fname, oname)

    h5a = h5py.h5a.open(h5o, b"datatype")
    try:
        type_attr = np.empty((), h5a.dtype)
        h5a.read(type_attr)
    finally:
        h5a.close()
    type_attr = type_attr.item().decode()
    lgdotype = datatype.datatype(type_attr)

    # scalars are dim-0 datasets and structs don't have rows
    if lgdotype is types.Scalar or lgdotype is types.Struct:
        return None

    # tables should have elements with all the same length
    if lgdotype is types.Table:
        rows_read = None
        for field in datatype.get_struct_fields(type_attr):
            n_rows_read = _read_child_n_rows(h5o, field, fname, oname)
            # compare against None explicitly: 0 is a valid row count and
            # must not be mistaken for "no reference length yet"
            if rows_read is None:
                rows_read = n_rows_read
            elif rows_read != n_rows_read:
                log.warning(
                    "'%s' field in table '%s' has %s rows, %s was expected",
                    field,
                    oname,
                    n_rows_read,
                    rows_read,
                )

        return rows_read

    # length of vector of vectors is the length of its cumulative_length
    if lgdotype is types.VectorOfVectors:
        return _read_child_n_rows(h5o, "cumulative_length", fname, oname)

    # length of (vector/array of) encoded vectors is the length of its
    # encoded_data
    if lgdotype in (types.VectorOfEncodedVectors, types.ArrayOfEncodedEqualSizedArrays):
        return _read_child_n_rows(h5o, "encoded_data", fname, oname)

    # return array length (without reading the array!)
    if issubclass(lgdotype, types.Array):
        # only the dataspace is inspected; the first dimension is the row count
        return h5o.get_space().shape[0]

    msg = f"don't know how to read rows of LGDO {lgdotype.__name__}"
    raise LH5DecodeError(msg, fname, oname)
48 changes: 3 additions & 45 deletions src/lgdo/lh5/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import h5py

from .. import types
from . import _serializers, datatype
from . import _serializers
from .exceptions import LH5DecodeError

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -47,54 +47,12 @@ def read_n_rows(name: str, h5f: str | h5py.File) -> int | None:
h5f = h5py.File(h5f, "r")

try:
attrs = h5f[name].attrs
h5o = h5f[name].id
except KeyError as e:
msg = "not found"
raise LH5DecodeError(msg, h5f, name) from e
except AttributeError as e:
msg = "missing 'datatype' attribute"
raise LH5DecodeError(msg, h5f, name) from e

lgdotype = datatype.datatype(attrs["datatype"])

# scalars are dim-0 datasets
if lgdotype is types.Scalar:
return None

# structs don't have rows
if lgdotype is types.Struct:
return None

# tables should have elements with all the same length
if lgdotype is types.Table:
# read out each of the fields
rows_read = None
for field in datatype.get_struct_fields(attrs["datatype"]):
n_rows_read = read_n_rows(name + "/" + field, h5f)
if not rows_read:
rows_read = n_rows_read
elif rows_read != n_rows_read:
log.warning(
f"'{field}' field in table '{name}' has {rows_read} rows, "
f"{n_rows_read} was expected"
)
return rows_read

# length of vector of vectors is the length of its cumulative_length
if lgdotype is types.VectorOfVectors:
return read_n_rows(f"{name}/cumulative_length", h5f)

# length of vector of encoded vectors is the length of its decoded_size
if lgdotype in (types.VectorOfEncodedVectors, types.ArrayOfEncodedEqualSizedArrays):
return read_n_rows(f"{name}/encoded_data", h5f)

# return array length (without reading the array!)
if issubclass(lgdotype, types.Array):
# compute the number of rows to read
return h5f[name].shape[0]

msg = f"don't know how to read rows of LGDO {lgdotype.__name__}"
raise LH5DecodeError(msg, h5f, name)
return _serializers.read.utils.read_n_rows(h5o, h5f.name, name)


def get_h5_group(
Expand Down

0 comments on commit 21bd790

Please sign in to comment.