From bdadd048e0694e0d450540a140fa81c907591ecf Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Tue, 21 Nov 2023 09:44:19 +0100 Subject: [PATCH] Fix indexing error for 0-dimensional HDF5 datasets * Refactor ProfileResultHDF5Writer to be more readable * Raise more informative error if writing to HDF5 fails * Don't index 0-dimensional datasets (Fixes #1205) --- pypesto/result/profile.py | 2 +- pypesto/store/hdf5.py | 12 ++++++++--- pypesto/store/save_to_hdf5.py | 39 ++++++++++++++++++++++++----------- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/pypesto/result/profile.py b/pypesto/result/profile.py index 0d19ef871..33be1da0a 100644 --- a/pypesto/result/profile.py +++ b/pypesto/result/profile.py @@ -183,7 +183,7 @@ class ProfileResult: """ def __init__(self): - self.list = [] + self.list: list[list[ProfilerResult]] = [] def append_empty_profile_list(self) -> int: """Append an empty profile list to the list of profile lists. diff --git a/pypesto/store/hdf5.py b/pypesto/store/hdf5.py index 259e745ae..9d2b738a1 100644 --- a/pypesto/store/hdf5.py +++ b/pypesto/store/hdf5.py @@ -45,7 +45,9 @@ def write_string_array(f: h5py.Group, path: str, strings: Collection) -> None: """ dt = h5py.special_dtype(vlen=str) dset = f.create_dataset(path, (len(strings),), dtype=dt) - dset[:] = [s.encode('utf8') for s in strings] + + if len(strings): + dset[:] = [s.encode('utf8') for s in strings] def write_float_array( @@ -69,7 +71,9 @@ def write_float_array( dset = f.create_dataset(path, (np.shape(values)), dtype=dtype) else: dset = f[path] - dset[:] = values + + if len(values): + dset[:] = values def write_int_array( @@ -90,4 +94,6 @@ def write_int_array( datatype """ dset = f.create_dataset(path, (len(values),), dtype=dtype) - dset[:] = values + + if len(values): + dset[:] = values diff --git a/pypesto/store/save_to_hdf5.py b/pypesto/store/save_to_hdf5.py index 86e116e8d..be6b80b29 100644 --- a/pypesto/store/save_to_hdf5.py +++ b/pypesto/store/save_to_hdf5.py @@ -8,7 +8,7 @@ import h5py import numpy as np -from ..result import Result, SampleResult +from ..result import ProfilerResult, Result, SampleResult from .hdf5 import write_array, write_float_array logger = logging.getLogger(__name__) @@ -234,20 +234,35 @@ def write(self, result: Result, overwrite: bool = False): profile_grp = profiling_grp.require_group(str(profile_id)) for parameter_id, parameter_profile in enumerate(profile): result_grp = profile_grp.require_group(str(parameter_id)) + self._write_profiler_result(parameter_profile, result_grp) - if parameter_profile is None: - result_grp.attrs['IsNone'] = True - continue - result_grp.attrs['IsNone'] = False - for key in parameter_profile.keys(): - if isinstance(parameter_profile[key], np.ndarray): - write_float_array( - result_grp, key, parameter_profile[key] - ) - elif parameter_profile[key] is not None: - result_grp.attrs[key] = parameter_profile[key] f.flush() + @staticmethod + def _write_profiler_result( + parameter_profile: Union[ProfilerResult, None], result_grp: h5py.Group + ) -> None: + """Write a single ProfilerResult to hdf5. + + Writes a single profile for a single parameter to the provided HDF5 group. + """ + if parameter_profile is None: + result_grp.attrs['IsNone'] = True + return + + result_grp.attrs['IsNone'] = False + + for key, value in parameter_profile.items(): + try: + if isinstance(value, np.ndarray): + write_float_array(result_grp, key, value) + elif value is not None: + result_grp.attrs[key] = value + except Exception as e: + raise ValueError( + f"Error writing {key} ({value}) to {result_grp}." + ) from e + def write_result( result: Result,