diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index 11e4219c0e..5239927aad 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -14,6 +14,7 @@ from zarr.core.group import AsyncGroup from zarr.core.metadata.v2 import ArrayV2Metadata from zarr.core.metadata.v3 import ArrayV3Metadata +from zarr.errors import NodeTypeValidationError from zarr.storage import ( StoreLike, StorePath, @@ -247,7 +248,10 @@ async def open( try: return await open_array(store=store_path, zarr_format=zarr_format, **kwargs) - except KeyError: + except (KeyError, NodeTypeValidationError): + # KeyError for a missing key + # NodeTypeValidationError for failing to parse node metadata as an array when it's + # actually a group return await open_group(store=store_path, zarr_format=zarr_format, **kwargs) @@ -580,6 +584,8 @@ async def open_group( meta_array : array-like, optional An array instance to use for determining arrays to create and return to users. Use `numpy.empty(())` by default. + attributes : dict + A dictionary of JSON-serializable values with user-defined attributes. Returns ------- diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index f5d614058a..d76216b781 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -207,6 +207,7 @@ def open_group( zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, meta_array: Any | None = None, # not used in async api + attributes: dict[str, JSON] | None = None, ) -> Group: return Group( sync( @@ -221,6 +222,7 @@ def open_group( zarr_version=zarr_version, zarr_format=zarr_format, meta_array=meta_array, + attributes=attributes, ) ) ) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 8d63d9c321..fdf9aa623e 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -67,6 +67,7 @@ from zarr.core.metadata.v2 import ArrayV2Metadata from zarr.core.metadata.v3 import ArrayV3Metadata from zarr.core.sync import collect_aiterator, sync +from zarr.errors import MetadataValidationError from zarr.registry import get_pipeline_class from zarr.storage import StoreLike, make_store_path from zarr.storage.common import StorePath, ensure_no_existing_node @@ -144,7 +145,7 @@ async def get_array_metadata( else: zarr_format = 2 else: - raise ValueError(f"unexpected zarr_format: {zarr_format}") + raise MetadataValidationError("zarr_format", "2, 3, or None", zarr_format) metadata_dict: dict[str, Any] if zarr_format == 2: diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index bf0e385d06..4d43570325 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -29,6 +29,7 @@ ) from zarr.core.config import config from zarr.core.sync import SyncMixin, sync +from zarr.errors import MetadataValidationError from zarr.storage import StoreLike, make_store_path from zarr.storage.common import StorePath, ensure_no_existing_node @@ -196,7 +197,7 @@ async def open( else: zarr_format = 2 else: - raise ValueError(f"unexpected zarr_format: {zarr_format}") + raise MetadataValidationError("zarr_format", "2, 3, or None", zarr_format) if zarr_format == 2: # V2 groups are comprised of a .zgroup and .zattrs objects diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 47c6106bfe..12e26746df 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -29,6 +29,7 @@ from zarr.core.config import config from zarr.core.metadata.common import ArrayMetadata, parse_attributes from zarr.core.strings import _STRING_DTYPE as STRING_NP_DTYPE +from zarr.errors import MetadataValidationError, NodeTypeValidationError from zarr.registry import get_codec_class DEFAULT_DTYPE = "float64" @@ -37,13 +38,13 @@ def parse_zarr_format(data: object) -> Literal[3]: if data == 3: return 3 - raise ValueError(f"Invalid value. Expected 3. Got {data}.") + raise MetadataValidationError("zarr_format", 3, data) def parse_node_type_array(data: object) -> Literal["array"]: if data == "array": return "array" - raise ValueError(f"Invalid value. Expected 'array'. Got {data}.") + raise NodeTypeValidationError("node_type", "array", data) def parse_codecs(data: object) -> tuple[Codec, ...]: diff --git a/src/zarr/errors.py b/src/zarr/errors.py index 72efcedf2e..e6d416bcc6 100644 --- a/src/zarr/errors.py +++ b/src/zarr/errors.py @@ -25,6 +25,21 @@ class ContainsArrayAndGroupError(_BaseZarrError): ) +class MetadataValidationError(_BaseZarrError): + """An exception raised when the Zarr metadata is invalid in some way""" + + _msg = "Invalid value for '{}'. Expected '{}'. Got '{}'." + + +class NodeTypeValidationError(MetadataValidationError): + """ + Specialized exception when the node_type of the metadata document is incorrect.. + + This can be raised when the value is invalid or unexpected given the context, + for example an 'array' node when we expected a 'group'. + """ + + __all__ = [ "ContainsArrayAndGroupError", "ContainsArrayError", diff --git a/tests/v3/test_api.py b/tests/v3/test_api.py index 218aec5c97..0614185f68 100644 --- a/tests/v3/test_api.py +++ b/tests/v3/test_api.py @@ -6,10 +6,13 @@ from numpy.testing import assert_array_equal import zarr +import zarr.api.asynchronous +import zarr.core.group from zarr import Array, Group from zarr.abc.store import Store from zarr.api.synchronous import create, group, load, open, open_group, save, save_array, save_group from zarr.core.common import ZarrFormat +from zarr.errors import MetadataValidationError from zarr.storage.memory import MemoryStore @@ -921,3 +924,37 @@ def test_open_group_positional_args_deprecated() -> None: store = MemoryStore({}, mode="w") with pytest.warns(FutureWarning, match="pass"): open_group(store, "w") + + +def test_open_falls_back_to_open_group() -> None: + # https://github.com/zarr-developers/zarr-python/issues/2309 + store = MemoryStore(mode="w") + zarr.open_group(store, attributes={"key": "value"}) + + group = zarr.open(store) + assert isinstance(group, Group) + assert group.attrs == {"key": "value"} + + +async def test_open_falls_back_to_open_group_async() -> None: + # https://github.com/zarr-developers/zarr-python/issues/2309 + store = MemoryStore(mode="w") + await zarr.api.asynchronous.open_group(store, attributes={"key": "value"}) + + group = await zarr.api.asynchronous.open(store=store) + assert isinstance(group, zarr.core.group.AsyncGroup) + assert group.attrs == {"key": "value"} + + +async def test_metadata_validation_error() -> None: + with pytest.raises( + MetadataValidationError, + match="Invalid value for 'zarr_format'. Expected '2, 3, or None'. Got '3.0'.", + ): + await zarr.api.asynchronous.open_group(zarr_format="3.0") # type: ignore[arg-type] + + with pytest.raises( + MetadataValidationError, + match="Invalid value for 'zarr_format'. Expected '2, 3, or None'. Got '3.0'.", + ): + await zarr.api.asynchronous.open_array(shape=(1,), zarr_format="3.0") # type: ignore[arg-type] diff --git a/tests/v3/test_metadata/test_v3.py b/tests/v3/test_metadata/test_v3.py index b55b2d7a65..f3b2968680 100644 --- a/tests/v3/test_metadata/test_v3.py +++ b/tests/v3/test_metadata/test_v3.py @@ -24,6 +24,7 @@ default_fill_value, parse_dimension_names, parse_fill_value, + parse_node_type_array, parse_zarr_format, ) @@ -54,7 +55,9 @@ @pytest.mark.parametrize("data", [None, 1, 2, 4, 5, "3"]) def test_parse_zarr_format_invalid(data: Any) -> None: - with pytest.raises(ValueError, match=f"Invalid value. Expected 3. Got {data}"): + with pytest.raises( + ValueError, match=f"Invalid value for 'zarr_format'. Expected '3'. Got '{data}'." + ): parse_zarr_format(data) @@ -62,6 +65,18 @@ def test_parse_zarr_format_valid() -> None: assert parse_zarr_format(3) == 3 +@pytest.mark.parametrize("data", [None, "group"]) +def test_parse_node_type_arrayinvalid(data: Any) -> None: + with pytest.raises( + ValueError, match=f"Invalid value for 'node_type'. Expected 'array'. Got '{data}'." + ): + parse_node_type_array(data) + + +def test_parse_node_typevalid() -> None: + assert parse_node_type_array("array") == "array" + + @pytest.mark.parametrize("data", [(), [1, 2, "a"], {"foo": 10}]) def parse_dimension_names_invalid(data: Any) -> None: with pytest.raises(TypeError, match="Expected either None or iterable of str,"):