From 2495813cc47d74630d6e63d5962730701641935e Mon Sep 17 00:00:00 2001
From: Jan Janssen
Date: Sat, 17 Aug 2024 16:00:29 +0200
Subject: [PATCH] Fix compatibility for read_nested_dict_from_hdf() (#61)

* Fix compatibility for read_nested_dict_from_hdf()

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 h5io_browser/base.py | 110 ++++++++++++++++++++++++++++---------------
 tests/test_base.py   |  39 +++++++++++++++
 2 files changed, 111 insertions(+), 38 deletions(-)

diff --git a/h5io_browser/base.py b/h5io_browser/base.py
index 8272178..faad8cd 100644
--- a/h5io_browser/base.py
+++ b/h5io_browser/base.py
@@ -21,6 +21,7 @@
     "csr_matrix",
     "csc_array",
     "csr_array",
+    "dict",
     "multiarray",
 )
 
@@ -81,19 +82,9 @@ def read_dict_from_hdf(file_name, h5_path, recursive=False, slash="ignore"):
             any type supported by ``write_hdf5``.
     """
     with h5py.File(file_name, "r") as hdf:
-        if recursive:
-            nodes_lst = _get_hdf_content(
-                hdf=hdf[h5_path], recursive=recursive, only_nodes=True
-            )
-        else:
-            nodes_lst = [h5_path]
-        if len(nodes_lst) > 0 and nodes_lst[0] != "/":
-            return {
-                n: _read_hdf(hdf_filehandle=hdf, h5_path=n, slash=slash)
-                for n in nodes_lst
-            }
-        else:
-            return {}
+        return _read_dict_from_open_hdf(
+            hdf_filehandle=hdf, h5_path=h5_path, recursive=recursive, slash=slash
+        )
 
 
 def read_nested_dict_from_hdf(
@@ -119,32 +110,44 @@ def read_nested_dict_from_hdf(
     if h5_path[0] != "/":
         h5_path = "/" + h5_path
     with h5py.File(file_name, "r") as hdf:
-        nodes_lst = _get_hdf_content(
-            hdf=hdf[h5_path], recursive=recursive, only_nodes=True
-        )
-        if not recursive and len(nodes_lst) == 0 and h5_path != "/":
-            nodes_lst += [h5_path]
-        if len(group_paths) > 0:
-            for group in group_paths:
-                nodes_lst += _get_hdf_content(
-                    hdf=hdf[posixpath.join(h5_path, group)],
-                    recursive=recursive,
-                    only_nodes=True,
-                )
-        if len(nodes_lst) > 0:
-            return_dict = {}
-            for n in nodes_lst:
-                return_dict = _merge_nested_dict(
-                    main_dict=return_dict,
-                    add_dict=_get_nested_dict_item(
-                        key=n,
-                        value=_read_hdf(hdf_filehandle=hdf, h5_path=n, slash=slash),
-                        h5_path=h5_path,
-                    ),
-                )
-            return return_dict
+        group_attrs_dict = hdf[h5_path].attrs
+        if (
+            "TITLE" in group_attrs_dict.keys()
+            and group_attrs_dict["TITLE"] in H5IO_GROUP_TYPES
+        ):
+            return _read_dict_from_open_hdf(
+                hdf_filehandle=hdf,
+                h5_path=h5_path[1:],
+                recursive=recursive,
+                slash=slash,
+            )
         else:
-            return {}
+            nodes_lst = _get_hdf_content(
+                hdf=hdf[h5_path], recursive=recursive, only_nodes=True
+            )
+            if not recursive and len(nodes_lst) == 0 and h5_path != "/":
+                nodes_lst += [h5_path]
+            if len(group_paths) > 0:
+                for group in group_paths:
+                    nodes_lst += _get_hdf_content(
+                        hdf=hdf[posixpath.join(h5_path, group)],
+                        recursive=recursive,
+                        only_nodes=True,
+                    )
+            if len(nodes_lst) > 0:
+                return_dict = {}
+                for n in nodes_lst:
+                    return_dict = _merge_nested_dict(
+                        main_dict=return_dict,
+                        add_dict=_get_nested_dict_item(
+                            key=n,
+                            value=_read_hdf(hdf_filehandle=hdf, h5_path=n, slash=slash),
+                            h5_path=h5_path,
+                        ),
+                    )
+                return return_dict
+            else:
+                return {}
 
 
 def write_dict_to_hdf(file_name, data_dict, compression=4, slash="error"):
@@ -293,6 +296,37 @@ def _read_hdf(hdf_filehandle, h5_path, slash="ignore"):
     )
 
 
+def _read_dict_from_open_hdf(hdf_filehandle, h5_path, recursive=False, slash="ignore"):
+    """
+    Read data from an open HDF5 file into a dictionary. By default only the nodes are converted to dictionaries;
+    additional sub-groups can be converted using the ``recursive`` parameter.
+
+    Args:
+        hdf_filehandle (h5py.File): Open HDF5 file
+        h5_path (str): Path to a group in the HDF5 file from where the data is read
+        recursive (bool/int): Recursively browse through the HDF5 file, either a boolean flag or an integer
+                              which specifies the level of recursion.
+        slash (str): 'ignore' | 'replace'. Whether to replace the string {FWDSLASH} with the value /. This does
+                     not apply to the top-level name (title). If 'ignore', nothing will be replaced.
+    Returns:
+        dict: The loaded data as dictionary, with the keys being the path inside the HDF5 file. The values can be of
+            any type supported by ``write_hdf5``.
+    """
+    if recursive:
+        nodes_lst = _get_hdf_content(
+            hdf=hdf_filehandle[h5_path], recursive=recursive, only_nodes=True
+        )
+    else:
+        nodes_lst = [h5_path]
+    if len(nodes_lst) > 0 and nodes_lst[0] != "/":
+        return {
+            n: _read_hdf(hdf_filehandle=hdf_filehandle, h5_path=n, slash=slash)
+            for n in nodes_lst
+        }
+    else:
+        return {}
+
+
 def _write_hdf(
     hdf_filehandle,
     h5_path,
diff --git a/tests/test_base.py b/tests/test_base.py
index 917c178..398dd72 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -3,6 +3,7 @@
 import h5py
 from unittest import TestCase
 import posixpath
+import h5io
 from h5io_browser import (
     delete_item,
     list_hdf,
@@ -398,3 +399,41 @@ def test_delete(self):
         nodes, groups = list_hdf(file_name=self.file_name, h5_path="/data_json")
         self.assertEqual(groups, [])
         self.assertEqual(nodes, ["/data_json/a"])
+
+
+class TestCompatibility(TestCase):
+    def setUp(self):
+        self.file_name = "testcomp.h5"
+        self.data = {
+            "array": np.ones(4) * 42,
+            "b": 42,
+        }
+        self.h5_path = "h5io"
+        h5io.write_hdf5(self.file_name, self.data)
+
+    def test_h5io(self):
+        dataread = h5io.read_hdf5(self.file_name, self.h5_path)
+        for k, v in self.data.items():
+            if isinstance(v, np.ndarray):
+                self.assertTrue(all(np.equal(v, dataread[k])))
+            else:
+                self.assertEqual(v, dataread[k])
+
+    def test_read_dict_from_hdf(self):
+        dataread = read_dict_from_hdf(self.file_name, self.h5_path)
+        for k, v in self.data.items():
+            if isinstance(v, np.ndarray):
+                self.assertTrue(all(np.equal(v, dataread[self.h5_path][k])))
+            else:
+                self.assertEqual(v, dataread[self.h5_path][k])
+
+    def test_read_nested_dict_from_hdf(self):
+        dataread = read_nested_dict_from_hdf(self.file_name, self.h5_path)
+        for k, v in self.data.items():
+            if isinstance(v, np.ndarray):
+                self.assertTrue(all(np.equal(v, dataread[self.h5_path][k])))
+            else:
+                self.assertEqual(v, dataread[self.h5_path][k])
+
+    def tearDown(self):
+        os.remove(self.file_name)
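
For context, the compatibility path this patch introduces can be exercised end to end as follows. This is a minimal sketch, not part of the patch: the file name `example.h5` is arbitrary, and it assumes `h5io` plus an `h5io_browser` build containing this fix are installed.

```python
import os

import h5io
import numpy as np
from h5io_browser import read_dict_from_hdf, read_nested_dict_from_hdf

# h5io stores a plain dictionary as an HDF5 group whose "TITLE" attribute
# is "dict" - the type tag this patch adds to H5IO_GROUP_TYPES.
data = {"array": np.ones(4) * 42, "b": 42}
h5io.write_hdf5("example.h5", data)  # h5io's default title/group is "h5io"

# Before the patch, read_nested_dict_from_hdf() walked such a group node by
# node; now it detects the "dict" TITLE and delegates to the shared helper
# _read_dict_from_open_hdf(), so the group reads back as one dictionary.
nested = read_nested_dict_from_hdf("example.h5", "h5io")
print(nested["h5io"]["b"])      # 42
print(nested["h5io"]["array"])  # [42. 42. 42. 42.]

# read_dict_from_hdf() goes through the same helper after the refactoring.
flat = read_dict_from_hdf("example.h5", "h5io")
print(flat["h5io"]["b"])        # 42

os.remove("example.h5")
```

This mirrors the new `TestCompatibility` cases above: the same keys are reachable under the `h5io` group path whether the file is read back with `h5io` itself or with either `h5io_browser` reader.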