Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Group.arrays, groups compatible with v2 #2213

Merged
merged 2 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 40 additions & 27 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from zarr.store.common import ensure_no_existing_node

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Iterable, Iterator
from collections.abc import AsyncGenerator, Generator, Iterable, Iterator
from typing import Any

from zarr.abc.codec import Codec
Expand Down Expand Up @@ -678,29 +678,31 @@ async def contains(self, member: str) -> bool:
else:
return True

# todo: decide if this method should be separate from `groups`
async def group_keys(self) -> AsyncGenerator[str, None]:
async for key, value in self.members():
async def groups(self) -> AsyncGenerator[tuple[str, AsyncGroup], None]:
async for name, value in self.members():
if isinstance(value, AsyncGroup):
yield key
yield name, value

# todo: decide if this method should be separate from `group_keys`
async def groups(self) -> AsyncGenerator[AsyncGroup, None]:
async for _, value in self.members():
if isinstance(value, AsyncGroup):
yield value
async def group_keys(self) -> AsyncGenerator[str, None]:
async for key, _ in self.groups():
yield key

# todo: decide if this method should be separate from `arrays`
async def array_keys(self) -> AsyncGenerator[str, None]:
async def group_values(self) -> AsyncGenerator[AsyncGroup, None]:
async for _, group in self.groups():
yield group

async def arrays(self) -> AsyncGenerator[tuple[str, AsyncArray], None]:
async for key, value in self.members():
if isinstance(value, AsyncArray):
yield key
yield key, value

# todo: decide if this method should be separate from `array_keys`
async def arrays(self) -> AsyncGenerator[AsyncArray, None]:
async for _, value in self.members():
if isinstance(value, AsyncArray):
yield value
async def array_keys(self) -> AsyncGenerator[str, None]:
async for key, _ in self.arrays():
yield key

async def array_values(self) -> AsyncGenerator[AsyncArray, None]:
async for _, array in self.arrays():
yield array

async def tree(self, expand: bool = False, level: int | None = None) -> Any:
raise NotImplementedError
Expand Down Expand Up @@ -861,18 +863,29 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group],
def __contains__(self, member: str) -> bool:
return self._sync(self._async_group.contains(member))

def group_keys(self) -> tuple[str, ...]:
return tuple(self._sync_iter(self._async_group.group_keys()))
def groups(self) -> Generator[tuple[str, Group], None]:
for name, async_group in self._sync_iter(self._async_group.groups()):
yield name, Group(async_group)

def group_keys(self) -> Generator[str, None]:
for name, _ in self.groups():
yield name

def group_values(self) -> Generator[Group, None]:
for _, group in self.groups():
yield group

def groups(self) -> tuple[Group, ...]:
# TODO: in v2 this was a generator that returned key: Group
return tuple(Group(obj) for obj in self._sync_iter(self._async_group.groups()))
def arrays(self) -> Generator[tuple[str, Array], None]:
for name, async_array in self._sync_iter(self._async_group.arrays()):
yield name, Array(async_array)

def array_keys(self) -> tuple[str, ...]:
return tuple(self._sync_iter(self._async_group.array_keys()))
def array_keys(self) -> Generator[str, None]:
for name, _ in self.arrays():
yield name

def arrays(self) -> tuple[Array, ...]:
return tuple(Array(obj) for obj in self._sync_iter(self._async_group.arrays()))
def array_values(self) -> Generator[Array, None]:
for _, array in self.arrays():
yield array

def tree(self, expand: bool = False, level: int | None = None) -> Any:
return self._sync(self._async_group.tree(expand=expand, level=level))
Expand Down
44 changes: 19 additions & 25 deletions tests/v3/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,34 +301,28 @@ def test_group_contains(store: Store, zarr_format: ZarrFormat) -> None:
assert "foo" in group


def test_group_subgroups(store: Store, zarr_format: ZarrFormat) -> None:
"""
Test the behavior of `Group` methods for accessing subgroups, namely `Group.group_keys` and `Group.groups`
"""
def test_group_child_iterators(store: Store, zarr_format: ZarrFormat):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to combine these tests since the setup is the same. The combined test also benefits from having both groups and arrays in the Group being accessed, since the code needs to filter down to just the child Groups / Arrays.

group = Group.create(store, zarr_format=zarr_format)
keys = ("foo", "bar")
subgroups_expected = tuple(group.create_group(k) for k in keys)
# create a sub-array as well
_ = group.create_array("array", shape=(10,))
subgroups_observed = group.groups()
assert set(group.group_keys()) == set(keys)
assert len(subgroups_observed) == len(subgroups_expected)
assert all(a in subgroups_observed for a in subgroups_expected)
expected_group_keys = ["g0", "g1"]
expected_group_values = [group.create_group(name=name) for name in expected_group_keys]
expected_groups = list(zip(expected_group_keys, expected_group_values, strict=False))

expected_group_values[0].create_group("subgroup")
expected_group_values[0].create_array("subarray", shape=(1,))

def test_group_subarrays(store: Store, zarr_format: ZarrFormat) -> None:
"""
Test the behavior of `Group` methods for accessing subarrays, namely `Group.array_keys` and `Group.arrays`
"""
group = Group.create(store, zarr_format=zarr_format)
keys = ("foo", "bar")
subarrays_expected = tuple(group.create_array(k, shape=(10,)) for k in keys)
# create a sub-group as well
_ = group.create_group("group")
subarrays_observed = group.arrays()
assert set(group.array_keys()) == set(keys)
assert len(subarrays_observed) == len(subarrays_expected)
assert all(a in subarrays_observed for a in subarrays_expected)
expected_array_keys = ["a0", "a1"]
expected_array_values = [
group.create_array(name=name, shape=(1,)) for name in expected_array_keys
]
expected_arrays = list(zip(expected_array_keys, expected_array_values, strict=False))

assert sorted(group.groups(), key=lambda x: x[0]) == expected_groups
assert sorted(group.group_keys()) == expected_group_keys
assert sorted(group.group_values(), key=lambda x: x.name) == expected_group_values

assert sorted(group.arrays(), key=lambda x: x[0]) == expected_arrays
assert sorted(group.array_keys()) == expected_array_keys
assert sorted(group.array_values(), key=lambda x: x.name) == expected_array_values


def test_group_update_attributes(store: Store, zarr_format: ZarrFormat) -> None:
Expand Down