From ceb3b361fb912dc748ac871018bcf7c757e4719d Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Sat, 14 Sep 2024 09:09:47 -0700 Subject: [PATCH] test: check that store, array, and group classes are serializable (#2006) * test: check that store, array, and group classes are serializable w/ pickle and can be dependably roundtripped * raise if MemoryStore is pickled * Apply suggestions from code review Co-authored-by: Davis Bennett * fix typos * new buffer __eq__ * pickle support for zip store --------- Co-authored-by: Davis Bennett --- src/zarr/abc/store.py | 5 + src/zarr/core/buffer/core.py | 6 ++ src/zarr/store/memory.py | 16 ++- src/zarr/store/remote.py | 10 ++ src/zarr/store/zip.py | 15 ++- src/zarr/testing/store.py | 14 +++ tests/v3/test_array.py | 36 ++++++- tests/v3/test_group.py | 165 +++-------------------------- tests/v3/test_store/test_memory.py | 52 ++------- tests/v3/test_store/test_remote.py | 2 +- 10 files changed, 127 insertions(+), 194 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 70ac9adc17..95d55a2ce0 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -83,6 +83,11 @@ def _check_writable(self) -> None: if self.mode.readonly: raise ValueError("store mode does not support writing") + @abstractmethod + def __eq__(self, value: object) -> bool: + """Equality comparison.""" + ... + @abstractmethod async def get( self, diff --git a/src/zarr/core/buffer/core.py b/src/zarr/core/buffer/core.py index ff26478ca9..0c6d966db9 100644 --- a/src/zarr/core/buffer/core.py +++ b/src/zarr/core/buffer/core.py @@ -281,6 +281,12 @@ def __add__(self, other: Buffer) -> Self: """Concatenate two buffers""" ... + def __eq__(self, other: object) -> bool: + # Another Buffer class can override this to choose a more efficient path + return isinstance(other, Buffer) and np.array_equal( + self.as_numpy_array(), other.as_numpy_array() + ) + class NDBuffer: """An n-dimensional memory block diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index 89e7ced31e..13e289f374 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from collections.abc import AsyncGenerator, MutableMapping +from typing import TYPE_CHECKING, Any from zarr.abc.store import Store from zarr.core.buffer import Buffer, gpu @@ -47,6 +48,19 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"MemoryStore({str(self)!r})" + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, type(self)) + and self._store_dict == other._store_dict + and self.mode == other.mode + ) + + def __setstate__(self, state: Any) -> None: + raise NotImplementedError(f"{type(self)} cannot be pickled") + + def __getstate__(self) -> None: + raise NotImplementedError(f"{type(self)} cannot be pickled") + async def get( self, key: str, diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index 3907ac3cc2..e3e2ba3447 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -51,6 +51,7 @@ def __init__( this must not be used. """ super().__init__(mode=mode) + self._storage_options = storage_options if isinstance(url, str): self._url = url.rstrip("/") self._fs, _path = fsspec.url_to_fs(url, **storage_options) @@ -91,6 +92,15 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"" + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, type(self)) + and self.path == other.path + and self.mode == other.mode + and self._url == other._url + # and self._storage_options == other._storage_options # FIXME: this isn't working for some reason + ) + async def get( self, key: str, diff --git a/src/zarr/store/zip.py b/src/zarr/store/zip.py index 15473aa674..ea31ad934a 100644 --- a/src/zarr/store/zip.py +++ b/src/zarr/store/zip.py @@ -5,7 +5,7 @@ import time import zipfile from pathlib import Path -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Any, Literal from zarr.abc.store import Store from zarr.core.buffer import Buffer, BufferPrototype @@ -68,7 +68,7 @@ def __init__( self.compression = compression self.allowZip64 = allowZip64 - async def _open(self) -> None: + def _sync_open(self) -> None: if self._is_open: raise ValueError("store is already open") @@ -83,6 +83,17 @@ async def _open(self) -> None: self._is_open = True + async def _open(self) -> None: + self._sync_open() + + def __getstate__(self) -> tuple[Path, ZipStoreAccessModeLiteral, int, bool]: + return self.path, self._zmode, self.compression, self.allowZip64 + + def __setstate__(self, state: Any) -> None: + self.path, self._zmode, self.compression, self.allowZip64 = state + self._is_open = False + self._sync_open() + def close(self) -> None: super().close() with self._lock: diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 65d7751a0d..a08b6960db 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -1,3 +1,4 @@ +import pickle from typing import Any, Generic, TypeVar import pytest @@ -48,6 +49,19 @@ def test_store_type(self, store: S) -> None: assert isinstance(store, Store) assert isinstance(store, self.store_cls) + def test_store_eq(self, store: S, store_kwargs: dict[str, Any]) -> None: + # check self equality + assert store == store + + # check store equality with same inputs + # asserting this is important for being able to compare (de)serialized stores + store2 = self.store_cls(**store_kwargs) + assert store == store2 + + def test_serizalizable_store(self, store: S) -> None: + foo = pickle.dumps(store) + assert pickle.loads(foo) == store + def test_store_mode(self, store: S, store_kwargs: dict[str, Any]) -> None: assert store.mode == AccessMode.from_literal("r+") assert not store.mode.readonly diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index cd20ab6e58..11be51682c 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -1,9 +1,10 @@ +import pickle from typing import Literal import numpy as np import pytest -from zarr import Array, Group +from zarr import Array, AsyncArray, Group from zarr.core.common import ZarrFormat from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.store import LocalStore, MemoryStore @@ -135,3 +136,36 @@ def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str assert arr.fill_value == np.dtype(dtype_str).type(fill_value) assert arr.fill_value.dtype == arr.dtype + + +@pytest.mark.parametrize("store", ("local",), indirect=["store"]) +@pytest.mark.parametrize("zarr_format", (2, 3)) +async def test_serializable_async_array( + store: LocalStore | MemoryStore, zarr_format: ZarrFormat +) -> None: + expected = await AsyncArray.create( + store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4" + ) + # await expected.setitems(list(range(100))) + + p = pickle.dumps(expected) + actual = pickle.loads(p) + + assert actual == expected + # np.testing.assert_array_equal(await actual.getitem(slice(None)), await expected.getitem(slice(None))) + # TODO: uncomment the parts of this test that will be impacted by the config/prototype changes in flight + + +@pytest.mark.parametrize("store", ("local",), indirect=["store"]) +@pytest.mark.parametrize("zarr_format", (2, 3)) +def test_serializable_sync_array(store: LocalStore, zarr_format: ZarrFormat) -> None: + expected = Array.create( + store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4" + ) + expected[:] = list(range(100)) + + p = pickle.dumps(expected) + actual = pickle.loads(p) + + assert actual == expected + np.testing.assert_array_equal(actual[:], expected[:]) diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index a62f367351..94b839a186 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -1,20 +1,19 @@ from __future__ import annotations +import pickle from typing import TYPE_CHECKING, Any, Literal, cast import numpy as np import pytest -import zarr.api.asynchronous from zarr import Array, AsyncArray, AsyncGroup, Group from zarr.abc.store import Store -from zarr.api.synchronous import open_group from zarr.core.buffer import default_buffer_prototype from zarr.core.common import JSON, ZarrFormat from zarr.core.group import GroupMetadata from zarr.core.sync import sync from zarr.errors import ContainsArrayError, ContainsGroupError -from zarr.store import LocalStore, MemoryStore, StorePath +from zarr.store import LocalStore, StorePath from zarr.store.common import make_store_path from .conftest import parse_store @@ -681,152 +680,22 @@ async def test_asyncgroup_update_attributes(store: Store, zarr_format: ZarrForma assert agroup_new_attributes.attrs == attributes_new -async def test_group_members_async(store: LocalStore | MemoryStore) -> None: - group = AsyncGroup( - GroupMetadata(), - store_path=StorePath(store=store, path="root"), - ) - a0 = await group.create_array("a0", shape=(1,)) - g0 = await group.create_group("g0") - a1 = await g0.create_array("a1", shape=(1,)) - g1 = await g0.create_group("g1") - a2 = await g1.create_array("a2", shape=(1,)) - g2 = await g1.create_group("g2") - - # immediate children - children = sorted([x async for x in group.members()], key=lambda x: x[0]) - assert children == [ - ("a0", a0), - ("g0", g0), - ] - - nmembers = await group.nmembers() - assert nmembers == 2 - - # partial - children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0]) - expected = [ - ("a0", a0), - ("g0", g0), - ("g0/a1", a1), - ("g0/g1", g1), - ] - assert children == expected - nmembers = await group.nmembers(max_depth=1) - assert nmembers == 4 - - # all children - all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0]) - expected = [ - ("a0", a0), - ("g0", g0), - ("g0/a1", a1), - ("g0/g1", g1), - ("g0/g1/a2", a2), - ("g0/g1/g2", g2), - ] - assert all_children == expected - - nmembers = await group.nmembers(max_depth=None) - assert nmembers == 6 - - with pytest.raises(ValueError, match="max_depth"): - [x async for x in group.members(max_depth=-1)] - - -async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) - - # create foo group - _ = await root.create_group("foo", attributes={"foo": 100}) - - # test that we can get the group using require_group - foo_group = await root.require_group("foo") - assert foo_group.attrs == {"foo": 100} - - # test that we can get the group using require_group and overwrite=True - foo_group = await root.require_group("foo", overwrite=True) - - _ = await foo_group.create_array( - "bar", shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100} +@pytest.mark.parametrize("store", ("local",), indirect=["store"]) +@pytest.mark.parametrize("zarr_format", (2, 3)) +async def test_serializable_async_group(store: LocalStore, zarr_format: ZarrFormat) -> None: + expected = await AsyncGroup.create( + store=store, attributes={"foo": 999}, zarr_format=zarr_format ) + p = pickle.dumps(expected) + actual = pickle.loads(p) + assert actual == expected - # test that overwriting a group w/ children fails - # TODO: figure out why ensure_no_existing_node is not catching the foo.bar array - # - # with pytest.raises(ContainsArrayError): - # await root.require_group("foo", overwrite=True) - - # test that requiring a group where an array is fails - with pytest.raises(TypeError): - await foo_group.require_group("bar") - - -async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) - # create foo group - _ = await root.create_group("foo", attributes={"foo": 100}) - # create bar group - _ = await root.create_group("bar", attributes={"bar": 200}) - - foo_group, bar_group = await root.require_groups("foo", "bar") - assert foo_group.attrs == {"foo": 100} - assert bar_group.attrs == {"bar": 200} - - # get a mix of existing and new groups - foo_group, spam_group = await root.require_groups("foo", "spam") - assert foo_group.attrs == {"foo": 100} - assert spam_group.attrs == {} - - # no names - no_group = await root.require_groups() - assert no_group == () - - -async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) - with pytest.warns(DeprecationWarning): - foo = await root.create_dataset("foo", shape=(10,), dtype="uint8") - assert foo.shape == (10,) - - with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning): - await root.create_dataset("foo", shape=(100,), dtype="int8") - - _ = await root.create_group("bar") - with pytest.raises(ContainsGroupError), pytest.warns(DeprecationWarning): - await root.create_dataset("bar", shape=(100,), dtype="int8") - - -async def test_require_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) - foo1 = await root.require_array("foo", shape=(10,), dtype="i8", attributes={"foo": 101}) - assert foo1.attrs == {"foo": 101} - foo2 = await root.require_array("foo", shape=(10,), dtype="i8") - assert foo2.attrs == {"foo": 101} - - # exact = False - _ = await root.require_array("foo", shape=10, dtype="f8") - - # errors w/ exact True - with pytest.raises(TypeError, match="Incompatible dtype"): - await root.require_array("foo", shape=(10,), dtype="f8", exact=True) - - with pytest.raises(TypeError, match="Incompatible shape"): - await root.require_array("foo", shape=(100, 100), dtype="i8") - - with pytest.raises(TypeError, match="Incompatible dtype"): - await root.require_array("foo", shape=(10,), dtype="f4") - - _ = await root.create_group("bar") - with pytest.raises(TypeError, match="Incompatible object"): - await root.require_array("bar", shape=(10,), dtype="int8") - - -async def test_open_mutable_mapping(): - group = await zarr.api.asynchronous.open_group(store={}, mode="w") - assert isinstance(group.store_path.store, MemoryStore) +@pytest.mark.parametrize("store", ("local",), indirect=["store"]) +@pytest.mark.parametrize("zarr_format", (2, 3)) +def test_serializable_sync_group(store: LocalStore, zarr_format: ZarrFormat) -> None: + expected = Group.create(store=store, attributes={"foo": 999}, zarr_format=zarr_format) + p = pickle.dumps(expected) + actual = pickle.loads(p) -def test_open_mutable_mapping_sync(): - group = open_group(store={}, mode="w") - assert isinstance(group.store_path.store, MemoryStore) + assert actual == expected diff --git a/tests/v3/test_store/test_memory.py b/tests/v3/test_store/test_memory.py index 13aaa20bda..04d17eb240 100644 --- a/tests/v3/test_store/test_memory.py +++ b/tests/v3/test_store/test_memory.py @@ -1,11 +1,12 @@ from __future__ import annotations +import pickle + import pytest -from zarr.core.buffer import Buffer, cpu, gpu -from zarr.store.memory import GpuMemoryStore, MemoryStore +from zarr.core.buffer import Buffer, cpu +from zarr.store.memory import MemoryStore from zarr.testing.store import StoreTests -from zarr.testing.utils import gpu_test class TestMemoryStore(StoreTests[MemoryStore, cpu.Buffer]): @@ -46,43 +47,12 @@ def test_store_supports_partial_writes(self, store: MemoryStore) -> None: def test_list_prefix(self, store: MemoryStore) -> None: assert True + def test_serizalizable_store(self, store: MemoryStore) -> None: + with pytest.raises(NotImplementedError): + store.__getstate__() -@gpu_test -class TestGpuMemoryStore(StoreTests[GpuMemoryStore, gpu.Buffer]): - store_cls = GpuMemoryStore - buffer_cls = gpu.Buffer - - def set(self, store: GpuMemoryStore, key: str, value: Buffer) -> None: - store._store_dict[key] = value - - def get(self, store: MemoryStore, key: str) -> Buffer: - return store._store_dict[key] - - @pytest.fixture(scope="function", params=[None, {}]) - def store_kwargs(self, request) -> dict[str, str | None | dict[str, Buffer]]: - return {"store_dict": request.param, "mode": "r+"} - - @pytest.fixture(scope="function") - def store(self, store_kwargs: str | None | dict[str, gpu.Buffer]) -> GpuMemoryStore: - return self.store_cls(**store_kwargs) - - def test_store_repr(self, store: GpuMemoryStore) -> None: - assert str(store) == f"gpumemory://{id(store._store_dict)}" - - def test_store_supports_writes(self, store: GpuMemoryStore) -> None: - assert store.supports_writes - - def test_store_supports_listing(self, store: GpuMemoryStore) -> None: - assert store.supports_listing - - def test_store_supports_partial_writes(self, store: GpuMemoryStore) -> None: - assert store.supports_partial_writes - - def test_list_prefix(self, store: GpuMemoryStore) -> None: - assert True - + with pytest.raises(NotImplementedError): + store.__setstate__({}) -def test_uses_dict() -> None: - store_dict = {} - store = MemoryStore(store_dict) - assert store._store_dict is store_dict + with pytest.raises(NotImplementedError): + pickle.dumps(store) diff --git a/tests/v3/test_store/test_remote.py b/tests/v3/test_store/test_remote.py index e400857c45..afa991209f 100644 --- a/tests/v3/test_store/test_remote.py +++ b/tests/v3/test_store/test_remote.py @@ -109,7 +109,7 @@ def store_kwargs(self, request) -> dict[str, str | bool]: anon = False mode = "r+" if request.param == "use_upath": - return {"mode": mode, "url": UPath(url, endpoint_url=endpoint_url, anon=anon)} + return {"url": UPath(url, endpoint_url=endpoint_url, anon=anon), "mode": mode} elif request.param == "use_str": return {"url": url, "mode": mode, "anon": anon, "endpoint_url": endpoint_url}