diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index a4e1e252ea..a45c7e1df7 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -33,7 +33,7 @@ from zarr.store.common import ensure_no_existing_node if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Iterable, Iterator + from collections.abc import AsyncGenerator, Generator, Iterable, Iterator from typing import Any from zarr.abc.codec import Codec @@ -678,29 +678,31 @@ async def contains(self, member: str) -> bool: else: return True - # todo: decide if this method should be separate from `groups` - async def group_keys(self) -> AsyncGenerator[str, None]: - async for key, value in self.members(): + async def groups(self) -> AsyncGenerator[tuple[str, AsyncGroup], None]: + async for name, value in self.members(): if isinstance(value, AsyncGroup): - yield key + yield name, value - # todo: decide if this method should be separate from `group_keys` - async def groups(self) -> AsyncGenerator[AsyncGroup, None]: - async for _, value in self.members(): - if isinstance(value, AsyncGroup): - yield value + async def group_keys(self) -> AsyncGenerator[str, None]: + async for key, _ in self.groups(): + yield key - # todo: decide if this method should be separate from `arrays` - async def array_keys(self) -> AsyncGenerator[str, None]: + async def group_values(self) -> AsyncGenerator[AsyncGroup, None]: + async for _, group in self.groups(): + yield group + + async def arrays(self) -> AsyncGenerator[tuple[str, AsyncArray], None]: async for key, value in self.members(): if isinstance(value, AsyncArray): - yield key + yield key, value - # todo: decide if this method should be separate from `array_keys` - async def arrays(self) -> AsyncGenerator[AsyncArray, None]: - async for _, value in self.members(): - if isinstance(value, AsyncArray): - yield value + async def array_keys(self) -> AsyncGenerator[str, None]: + async for key, _ in self.arrays(): + yield key + + async def array_values(self) -> AsyncGenerator[AsyncArray, None]: + async for _, array in self.arrays(): + yield array async def tree(self, expand: bool = False, level: int | None = None) -> Any: raise NotImplementedError @@ -861,18 +863,29 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group], def __contains__(self, member: str) -> bool: return self._sync(self._async_group.contains(member)) - def group_keys(self) -> tuple[str, ...]: - return tuple(self._sync_iter(self._async_group.group_keys())) + def groups(self) -> Generator[tuple[str, Group], None]: + for name, async_group in self._sync_iter(self._async_group.groups()): + yield name, Group(async_group) + + def group_keys(self) -> Generator[str, None]: + for name, _ in self.groups(): + yield name + + def group_values(self) -> Generator[Group, None]: + for _, group in self.groups(): + yield group - def groups(self) -> tuple[Group, ...]: - # TODO: in v2 this was a generator that return key: Group - return tuple(Group(obj) for obj in self._sync_iter(self._async_group.groups())) + def arrays(self) -> Generator[tuple[str, Array], None]: + for name, async_array in self._sync_iter(self._async_group.arrays()): + yield name, Array(async_array) - def array_keys(self) -> tuple[str, ...]: - return tuple(self._sync_iter(self._async_group.array_keys())) + def array_keys(self) -> Generator[str, None]: + for name, _ in self.arrays(): + yield name - def arrays(self) -> tuple[Array, ...]: - return tuple(Array(obj) for obj in self._sync_iter(self._async_group.arrays())) + def array_values(self) -> Generator[Array, None]: + for _, array in self.arrays(): + yield array def tree(self, expand: bool = False, level: int | None = None) -> Any: return self._sync(self._async_group.tree(expand=expand, level=level)) diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index 4bb23fddaa..6e75294b7c 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -301,34 +301,28 @@ def test_group_contains(store: Store, zarr_format: ZarrFormat) -> None: assert "foo" in group -def test_group_subgroups(store: Store, zarr_format: ZarrFormat) -> None: - """ - Test the behavior of `Group` methods for accessing subgroups, namely `Group.group_keys` and `Group.groups` - """ +def test_group_child_iterators(store: Store, zarr_format: ZarrFormat): group = Group.create(store, zarr_format=zarr_format) - keys = ("foo", "bar") - subgroups_expected = tuple(group.create_group(k) for k in keys) - # create a sub-array as well - _ = group.create_array("array", shape=(10,)) - subgroups_observed = group.groups() - assert set(group.group_keys()) == set(keys) - assert len(subgroups_observed) == len(subgroups_expected) - assert all(a in subgroups_observed for a in subgroups_expected) + expected_group_keys = ["g0", "g1"] + expected_group_values = [group.create_group(name=name) for name in expected_group_keys] + expected_groups = list(zip(expected_group_keys, expected_group_values, strict=False)) + expected_group_values[0].create_group("subgroup") + expected_group_values[0].create_array("subarray", shape=(1,)) -def test_group_subarrays(store: Store, zarr_format: ZarrFormat) -> None: - """ - Test the behavior of `Group` methods for accessing subgroups, namely `Group.group_keys` and `Group.groups` - """ - group = Group.create(store, zarr_format=zarr_format) - keys = ("foo", "bar") - subarrays_expected = tuple(group.create_array(k, shape=(10,)) for k in keys) - # create a sub-group as well - _ = group.create_group("group") - subarrays_observed = group.arrays() - assert set(group.array_keys()) == set(keys) - assert len(subarrays_observed) == len(subarrays_expected) - assert all(a in subarrays_observed for a in subarrays_expected) + expected_array_keys = ["a0", "a1"] + expected_array_values = [ + group.create_array(name=name, shape=(1,)) for name in expected_array_keys + ] + expected_arrays = list(zip(expected_array_keys, expected_array_values, strict=False)) + + assert sorted(group.groups(), key=lambda x: x[0]) == expected_groups + assert sorted(group.group_keys()) == expected_group_keys + assert sorted(group.group_values(), key=lambda x: x.name) == expected_group_values + + assert sorted(group.arrays(), key=lambda x: x[0]) == expected_arrays + assert sorted(group.array_keys()) == expected_array_keys + assert sorted(group.array_values(), key=lambda x: x.name) == expected_array_values def test_group_update_attributes(store: Store, zarr_format: ZarrFormat) -> None: