Skip to content

Commit

Permalink
Add array_metadata strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Feb 10, 2025
1 parent f4278a5 commit 67d4521
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 4 deletions.
64 changes: 61 additions & 3 deletions src/zarr/testing/strategies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import Any
from typing import Any, Literal

import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
Expand All @@ -8,9 +8,10 @@
from hypothesis.strategies import SearchStrategy

import zarr
from zarr.abc.store import RangeByteRequest
from zarr.abc.store import RangeByteRequest, Store
from zarr.core.array import Array
from zarr.core.common import ZarrFormat
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
from zarr.core.sync import sync
from zarr.storage import MemoryStore, StoreLike
from zarr.storage._common import _dereference_path
Expand Down Expand Up @@ -67,6 +68,11 @@ def safe_unicode_for_dtype(dtype: np.dtype[np.str_]) -> st.SearchStrategy[str]:
)


def clear_store(x: Store) -> Store:
sync(x.clear())
return x


# From https://zarr-specs.readthedocs.io/en/latest/v3/core/v3.0.html#node-names
# 1. must not be the empty string ("")
# 2. must not include the character "/"
Expand All @@ -85,12 +91,64 @@ def safe_unicode_for_dtype(dtype: np.dtype[np.str_]) -> st.SearchStrategy[str]:
# st.builds will only call a new store constructor for different keyword arguments
# i.e. stores.examples() will always return the same object per Store class.
# So we map a clear to reset the store.
stores = st.builds(MemoryStore, st.just({})).map(lambda x: sync(x.clear()))
stores = st.builds(MemoryStore, st.just({})).map(clear_store)
compressors = st.sampled_from([None, "default"])
zarr_formats: st.SearchStrategy[ZarrFormat] = st.sampled_from([2, 3])
array_shapes = npst.array_shapes(max_dims=4, min_side=0)


@st.composite # type: ignore[misc]
def dimension_names(draw: st.DrawFn, *, ndim: int | None = None) -> list[None | str] | None:
simple_text = st.text(zarr_key_chars, min_size=0)
return draw(st.none() | st.lists(st.none() | simple_text, min_size=ndim, max_size=ndim)) # type: ignore[no-any-return]


@st.composite # type: ignore[misc]
def array_metadata(
draw: st.DrawFn,
*,
array_shapes: st.SearchStrategy[tuple[int, ...]] = npst.array_shapes,
zarr_formats: st.SearchStrategy[Literal[2, 3]] = zarr_formats,
attributes: st.SearchStrategy[dict[str, Any]] = attrs,
) -> ArrayV2Metadata | ArrayV3Metadata:
from zarr.codecs.bytes import BytesCodec
from zarr.core.chunk_grids import RegularChunkGrid
from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding
from zarr.core.metadata.v3 import ArrayV3Metadata

zarr_format = draw(zarr_formats)
# separator = draw(st.sampled_from(['/', '\\']))
shape = draw(array_shapes())
ndim = len(shape)
chunk_shape = draw(array_shapes(min_dims=ndim, max_dims=ndim))
dtype = draw(v3_dtypes())
fill_value = draw(npst.from_dtype(dtype))
if zarr_format == 2:
return ArrayV2Metadata(
shape=shape,
chunks=chunk_shape,
dtype=dtype,
fill_value=fill_value,
order=draw(st.sampled_from(["C", "F"])),
attributes=draw(attributes),
dimension_separator=draw(st.sampled_from([".", "/"])),
filters=None,
compressor=None,
)
else:
return ArrayV3Metadata(
shape=shape,
data_type=dtype,
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
fill_value=fill_value,
attributes=draw(attributes),
dimension_names=draw(dimension_names(ndim=ndim)),
chunk_key_encoding=DefaultChunkKeyEncoding(separator="/"), # FIXME
codecs=[BytesCodec()],
storage_transformers=(),
)


@st.composite # type: ignore[misc]
def numpy_arrays(
draw: st.DrawFn,
Expand Down
24 changes: 23 additions & 1 deletion tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,24 @@
import pytest
from numpy.testing import assert_array_equal

from zarr.core.buffer import default_buffer_prototype

pytest.importorskip("hypothesis")

import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
from hypothesis import given

from zarr.testing.strategies import arrays, basic_indices, numpy_arrays, zarr_formats
from zarr.abc.store import Store
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
from zarr.testing.strategies import (
array_metadata,
arrays,
basic_indices,
numpy_arrays,
stores,
zarr_formats,
)


@given(data=st.data(), zarr_format=zarr_formats)
Expand Down Expand Up @@ -47,6 +58,17 @@ def test_vindex(data: st.DataObject) -> None:
assert_array_equal(nparray[indexer], actual)


@given(store=stores, meta=array_metadata()) # type: ignore[misc]
async def test_roundtrip_array_metadata(
store: Store, meta: ArrayV2Metadata | ArrayV3Metadata
) -> None:
asdict = meta.to_buffer_dict(prototype=default_buffer_prototype())
for key, expected in asdict.items():
await store.set(f"0/{key}", expected)
actual = await store.get(f"0/{key}", prototype=default_buffer_prototype())
assert actual == expected


# @st.composite
# def advanced_indices(draw, *, shape):
# basic_idxr = draw(
Expand Down

0 comments on commit 67d4521

Please sign in to comment.