Skip to content

Commit

Permalink
zarr.open should fall back to opening a group (#2310)
Browse files Browse the repository at this point in the history
* zarr.open should fall back to opening a group

Closes #2309

* fixup

* robuster

* fixup test

* Consistent version error

* more
  • Loading branch information
TomAugspurger authored Oct 10, 2024
1 parent 81a87d6 commit 395604d
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 6 deletions.
8 changes: 7 additions & 1 deletion src/zarr/api/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from zarr.core.group import AsyncGroup
from zarr.core.metadata.v2 import ArrayV2Metadata
from zarr.core.metadata.v3 import ArrayV3Metadata
from zarr.errors import NodeTypeValidationError
from zarr.storage import (
StoreLike,
StorePath,
Expand Down Expand Up @@ -247,7 +248,10 @@ async def open(

try:
return await open_array(store=store_path, zarr_format=zarr_format, **kwargs)
except KeyError:
except (KeyError, NodeTypeValidationError):
# KeyError for a missing key
# NodeTypeValidationError for failing to parse node metadata as an array when it's
# actually a group
return await open_group(store=store_path, zarr_format=zarr_format, **kwargs)


Expand Down Expand Up @@ -580,6 +584,8 @@ async def open_group(
meta_array : array-like, optional
An array instance to use for determining arrays to create and return
to users. Use `numpy.empty(())` by default.
attributes : dict
A dictionary of JSON-serializable values with user-defined attributes.
Returns
-------
Expand Down
2 changes: 2 additions & 0 deletions src/zarr/api/synchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def open_group(
zarr_version: ZarrFormat | None = None, # deprecated
zarr_format: ZarrFormat | None = None,
meta_array: Any | None = None, # not used in async api
attributes: dict[str, JSON] | None = None,
) -> Group:
return Group(
sync(
Expand All @@ -221,6 +222,7 @@ def open_group(
zarr_version=zarr_version,
zarr_format=zarr_format,
meta_array=meta_array,
attributes=attributes,
)
)
)
Expand Down
3 changes: 2 additions & 1 deletion src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from zarr.core.metadata.v2 import ArrayV2Metadata
from zarr.core.metadata.v3 import ArrayV3Metadata
from zarr.core.sync import collect_aiterator, sync
from zarr.errors import MetadataValidationError
from zarr.registry import get_pipeline_class
from zarr.storage import StoreLike, make_store_path
from zarr.storage.common import StorePath, ensure_no_existing_node
Expand Down Expand Up @@ -144,7 +145,7 @@ async def get_array_metadata(
else:
zarr_format = 2
else:
raise ValueError(f"unexpected zarr_format: {zarr_format}")
raise MetadataValidationError("zarr_format", "2, 3, or None", zarr_format)

metadata_dict: dict[str, Any]
if zarr_format == 2:
Expand Down
3 changes: 2 additions & 1 deletion src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from zarr.core.config import config
from zarr.core.sync import SyncMixin, sync
from zarr.errors import MetadataValidationError
from zarr.storage import StoreLike, make_store_path
from zarr.storage.common import StorePath, ensure_no_existing_node

Expand Down Expand Up @@ -196,7 +197,7 @@ async def open(
else:
zarr_format = 2
else:
raise ValueError(f"unexpected zarr_format: {zarr_format}")
raise MetadataValidationError("zarr_format", "2, 3, or None", zarr_format)

if zarr_format == 2:
# V2 groups are comprised of a .zgroup and .zattrs objects
Expand Down
5 changes: 3 additions & 2 deletions src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from zarr.core.config import config
from zarr.core.metadata.common import ArrayMetadata, parse_attributes
from zarr.core.strings import _STRING_DTYPE as STRING_NP_DTYPE
from zarr.errors import MetadataValidationError, NodeTypeValidationError
from zarr.registry import get_codec_class

DEFAULT_DTYPE = "float64"
Expand All @@ -37,13 +38,13 @@
def parse_zarr_format(data: object) -> Literal[3]:
if data == 3:
return 3
raise ValueError(f"Invalid value. Expected 3. Got {data}.")
raise MetadataValidationError("zarr_format", 3, data)


def parse_node_type_array(data: object) -> Literal["array"]:
if data == "array":
return "array"
raise ValueError(f"Invalid value. Expected 'array'. Got {data}.")
raise NodeTypeValidationError("node_type", "array", data)


def parse_codecs(data: object) -> tuple[Codec, ...]:
Expand Down
15 changes: 15 additions & 0 deletions src/zarr/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,21 @@ class ContainsArrayAndGroupError(_BaseZarrError):
)


class MetadataValidationError(_BaseZarrError):
    """Raised when a field of a Zarr metadata document fails validation.

    Call sites construct this error with three positional arguments: the
    name of the offending field, a description of the expected value(s),
    and the value that was actually received. They fill the three
    placeholders of ``_msg`` in that order.
    """

    # Message template. The formatting itself is presumably performed by
    # ``_BaseZarrError`` (defined elsewhere) -- confirm there.
    _msg = "Invalid value for '{}'. Expected '{}'. Got '{}'."


class NodeTypeValidationError(MetadataValidationError):
    """Raised when the ``node_type`` of a metadata document is incorrect.

    The value may be invalid outright, or merely unexpected given the
    context -- for example an 'array' node when a 'group' was expected.
    """


__all__ = [
"ContainsArrayAndGroupError",
"ContainsArrayError",
Expand Down
37 changes: 37 additions & 0 deletions tests/v3/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
from numpy.testing import assert_array_equal

import zarr
import zarr.api.asynchronous
import zarr.core.group
from zarr import Array, Group
from zarr.abc.store import Store
from zarr.api.synchronous import create, group, load, open, open_group, save, save_array, save_group
from zarr.core.common import ZarrFormat
from zarr.errors import MetadataValidationError
from zarr.storage.memory import MemoryStore


Expand Down Expand Up @@ -921,3 +924,37 @@ def test_open_group_positional_args_deprecated() -> None:
store = MemoryStore({}, mode="w")
with pytest.warns(FutureWarning, match="pass"):
open_group(store, "w")


def test_open_falls_back_to_open_group() -> None:
    """``zarr.open`` returns a ``Group`` when the store holds a group."""
    # Regression test for
    # https://github.com/zarr-developers/zarr-python/issues/2309
    memory_store = MemoryStore(mode="w")
    zarr.open_group(memory_store, attributes={"key": "value"})

    opened = zarr.open(memory_store)

    assert isinstance(opened, Group)
    assert opened.attrs == {"key": "value"}


async def test_open_falls_back_to_open_group_async() -> None:
    """The async ``open`` returns an ``AsyncGroup`` when the store holds a group."""
    # Regression test for
    # https://github.com/zarr-developers/zarr-python/issues/2309
    async_api = zarr.api.asynchronous
    memory_store = MemoryStore(mode="w")
    await async_api.open_group(memory_store, attributes={"key": "value"})

    opened = await async_api.open(store=memory_store)

    assert isinstance(opened, zarr.core.group.AsyncGroup)
    assert opened.attrs == {"key": "value"}


async def test_metadata_validation_error() -> None:
    """An unsupported ``zarr_format`` raises ``MetadataValidationError``.

    Both ``open_group`` and ``open_array`` are expected to report the same
    message for the same bad value.
    """
    # ``pytest.raises(match=...)`` treats this as a regular expression.
    expected_message = "Invalid value for 'zarr_format'. Expected '2, 3, or None'. Got '3.0'."

    with pytest.raises(MetadataValidationError, match=expected_message):
        await zarr.api.asynchronous.open_group(zarr_format="3.0")  # type: ignore[arg-type]

    with pytest.raises(MetadataValidationError, match=expected_message):
        await zarr.api.asynchronous.open_array(shape=(1,), zarr_format="3.0")  # type: ignore[arg-type]
17 changes: 16 additions & 1 deletion tests/v3/test_metadata/test_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
default_fill_value,
parse_dimension_names,
parse_fill_value,
parse_node_type_array,
parse_zarr_format,
)

Expand Down Expand Up @@ -54,14 +55,28 @@

@pytest.mark.parametrize("data", [None, 1, 2, 4, 5, "3"])
def test_parse_zarr_format_invalid(data: Any) -> None:
with pytest.raises(ValueError, match=f"Invalid value. Expected 3. Got {data}"):
with pytest.raises(
ValueError, match=f"Invalid value for 'zarr_format'. Expected '3'. Got '{data}'."
):
parse_zarr_format(data)


def test_parse_zarr_format_valid() -> None:
    """``parse_zarr_format`` returns 3 when given 3."""
    parsed = parse_zarr_format(3)
    assert parsed == 3


@pytest.mark.parametrize("data", [None, "group"])
def test_parse_node_type_arrayinvalid(data: Any) -> None:
    """``parse_node_type_array`` rejects any value other than ``"array"``."""
    # NOTE(review): expecting ValueError only works if the raised
    # NodeTypeValidationError derives from ValueError; its base class is
    # defined elsewhere -- confirm.
    expected_message = f"Invalid value for 'node_type'. Expected 'array'. Got '{data}'."
    with pytest.raises(ValueError, match=expected_message):
        parse_node_type_array(data)


def test_parse_node_typevalid() -> None:
    """``parse_node_type_array`` returns ``"array"`` when given ``"array"``."""
    parsed = parse_node_type_array("array")
    assert parsed == "array"


@pytest.mark.parametrize("data", [(), [1, 2, "a"], {"foo": 10}])
def parse_dimension_names_invalid(data: Any) -> None:
with pytest.raises(TypeError, match="Expected either None or iterable of str,"):
Expand Down

0 comments on commit 395604d

Please sign in to comment.