diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 79c68418da..99a69dc541 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ default_language_version: python: python3 repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.5 + rev: v0.6.7 hooks: - id: ruff args: ["--fix", "--show-fixes"] diff --git a/pyproject.toml b/pyproject.toml index 886cd5a0bc..63a58ac795 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,8 @@ test = [ "flask", "requests", "mypy", - "hypothesis" + "hypothesis", + "universal-pathlib", ] jupyter = [ @@ -273,6 +274,7 @@ filterwarnings = [ "ignore:PY_SSIZE_T_CLEAN will be required.*:DeprecationWarning", "ignore:The loop argument is deprecated since Python 3.8.*:DeprecationWarning", "ignore:Creating a zarr.buffer.gpu.*:UserWarning", + "ignore:Duplicate name:UserWarning", # from ZipFile ] markers = [ "gpu: mark a test as requiring CuPy and GPU" diff --git a/src/zarr/_compat.py b/src/zarr/_compat.py new file mode 100644 index 0000000000..52d96005cc --- /dev/null +++ b/src/zarr/_compat.py @@ -0,0 +1,68 @@ +import warnings +from collections.abc import Callable +from functools import wraps +from inspect import Parameter, signature +from typing import Any, TypeVar + +T = TypeVar("T") + +# Based off https://github.com/scikit-learn/scikit-learn/blob/e87b32a81c70abed8f2e97483758eb64df8255e9/sklearn/utils/validation.py#L63 + + +def _deprecate_positional_args( + func: Callable[..., T] | None = None, *, version: str = "3.1.0" +) -> Callable[..., T]: + """Decorator for methods that issues warnings for positional arguments. + + Using the keyword-only argument syntax in pep 3102, arguments after the + * will issue a warning when passed as a positional argument. + + Parameters + ---------- + func : callable, default=None + Function to check arguments on. + version : str, default="3.1.0" + The version when positional arguments will result in error. 
+ """ + + def _inner_deprecate_positional_args(f: Callable[..., T]) -> Callable[..., T]: + sig = signature(f) + kwonly_args = [] + all_args = [] + + for name, param in sig.parameters.items(): + if param.kind == Parameter.POSITIONAL_OR_KEYWORD: + all_args.append(name) + elif param.kind == Parameter.KEYWORD_ONLY: + kwonly_args.append(name) + + @wraps(f) + def inner_f(*args: Any, **kwargs: Any) -> T: + extra_args = len(args) - len(all_args) + if extra_args <= 0: + return f(*args, **kwargs) + + # extra_args > 0 + args_msg = [ + f"{name}={arg}" + for name, arg in zip(kwonly_args[:extra_args], args[-extra_args:], strict=False) + ] + formatted_args_msg = ", ".join(args_msg) + warnings.warn( + ( + f"Pass {formatted_args_msg} as keyword args. From version " + f"{version} passing these as positional arguments " + "will result in an error" + ), + FutureWarning, + stacklevel=2, + ) + kwargs.update(zip(sig.parameters, args, strict=False)) + return f(**kwargs) + + return inner_f + + if func is not None: + return _inner_deprecate_positional_args(func) + + return _inner_deprecate_positional_args # type: ignore[return-value] diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 95d55a2ce0..f95ba34efd 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator +from asyncio import gather +from collections.abc import AsyncGenerator, Iterable from typing import Any, NamedTuple, Protocol, runtime_checkable from typing_extensions import Self @@ -158,6 +159,13 @@ async def set(self, key: str, value: Buffer) -> None: """ ... + async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: + """ + Insert multiple (key, value) pairs into storage. 
+ """ + await gather(*(self.set(key, value) for key, value in values)) + return None + @property @abstractmethod def supports_deletes(self) -> bool: @@ -211,7 +219,9 @@ def list(self) -> AsyncGenerator[str, None]: @abstractmethod def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: - """Retrieve all keys in the store with a given prefix. + """ + Retrieve all keys in the store that begin with a given prefix. Keys are returned with the + common leading prefix removed. Parameters ---------- diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index 8a1b0c5f36..5fbb38c5e7 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -194,6 +194,7 @@ async def open( zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, path: str | None = None, + storage_options: dict[str, Any] | None = None, **kwargs: Any, # TODO: type kwargs as valid args to open_array ) -> AsyncArray | AsyncGroup: """Convenience function to open a group or array using file-mode-like semantics. @@ -211,6 +212,9 @@ async def open( The zarr format to use when saving. path : str or None, optional The path within the store to open. + storage_options : dict + If using an fsspec URL to create the store, these will be passed to + the backend implementation. Ignored otherwise. **kwargs Additional parameters are passed through to :func:`zarr.creation.open_array` or :func:`zarr.hierarchy.open_group`. @@ -221,7 +225,7 @@ async def open( Return type depends on what exists in the given store. 
""" zarr_format = _handle_zarr_version_or_format(zarr_version=zarr_version, zarr_format=zarr_format) - store_path = await make_store_path(store, mode=mode) + store_path = await make_store_path(store, mode=mode, storage_options=storage_options) if path is not None: store_path = store_path / path @@ -276,6 +280,7 @@ async def save_array( zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, path: str | None = None, + storage_options: dict[str, Any] | None = None, **kwargs: Any, # TODO: type kwargs as valid args to create ) -> None: """Convenience function to save a NumPy array to the local file system, following a @@ -291,6 +296,9 @@ async def save_array( The zarr format to use when saving. path : str or None, optional The path within the store where the array will be saved. + storage_options : dict + If using an fsspec URL to create the store, these will be passed to + the backend implementation. Ignored otherwise. kwargs Passed through to :func:`create`, e.g., compressor. """ @@ -299,7 +307,7 @@ async def save_array( or _default_zarr_version() ) - store_path = await make_store_path(store, mode="w") + store_path = await make_store_path(store, mode="w", storage_options=storage_options) if path is not None: store_path = store_path / path new = await AsyncArray.create( @@ -319,6 +327,7 @@ async def save_group( zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, path: str | None = None, + storage_options: dict[str, Any] | None = None, **kwargs: NDArrayLike, ) -> None: """Convenience function to save several NumPy arrays to the local file system, following a @@ -334,11 +343,17 @@ async def save_group( The zarr format to use when saving. path : str or None, optional Path within the store where the group will be saved. + storage_options : dict + If using an fsspec URL to create the store, these will be passed to + the backend implementation. Ignored otherwise. 
kwargs NumPy arrays with data to save. """ zarr_format = ( - _handle_zarr_version_or_format(zarr_version=zarr_version, zarr_format=zarr_format) + _handle_zarr_version_or_format( + zarr_version=zarr_version, + zarr_format=zarr_format, + ) or _default_zarr_version() ) @@ -346,10 +361,22 @@ async def save_group( raise ValueError("at least one array must be provided") aws = [] for i, arr in enumerate(args): - aws.append(save_array(store, arr, zarr_format=zarr_format, path=f"{path}/arr_{i}")) + aws.append( + save_array( + store, + arr, + zarr_format=zarr_format, + path=f"{path}/arr_{i}", + storage_options=storage_options, + ) + ) for k, arr in kwargs.items(): _path = f"{path}/{k}" if path is not None else k - aws.append(save_array(store, arr, zarr_format=zarr_format, path=_path)) + aws.append( + save_array( + store, arr, zarr_format=zarr_format, path=_path, storage_options=storage_options + ) + ) await asyncio.gather(*aws) @@ -418,6 +445,7 @@ async def group( zarr_format: ZarrFormat | None = None, meta_array: Any | None = None, # not used attributes: dict[str, JSON] | None = None, + storage_options: dict[str, Any] | None = None, ) -> AsyncGroup: """Create a group. @@ -444,6 +472,9 @@ async def group( to users. Use `numpy.empty(())` by default. zarr_format : {2, 3, None}, optional The zarr format to use when saving. + storage_options : dict + If using an fsspec URL to create the store, these will be passed to + the backend implementation. Ignored otherwise. 
Returns ------- @@ -453,7 +484,7 @@ async def group( zarr_format = _handle_zarr_version_or_format(zarr_version=zarr_version, zarr_format=zarr_format) - store_path = await make_store_path(store) + store_path = await make_store_path(store, storage_options=storage_options) if path is not None: store_path = store_path / path @@ -472,7 +503,7 @@ async def group( try: return await AsyncGroup.open(store=store_path, zarr_format=zarr_format) except (KeyError, FileNotFoundError): - return await AsyncGroup.create( + return await AsyncGroup.from_store( store=store_path, zarr_format=zarr_format or _default_zarr_version(), exists_ok=overwrite, @@ -481,14 +512,14 @@ async def group( async def open_group( - *, # Note: this is a change from v2 store: StoreLike | None = None, + *, # Note: this is a change from v2 mode: AccessModeLiteral | None = None, cache_attrs: bool | None = None, # not used, default changed synchronizer: Any = None, # not used path: str | None = None, chunk_store: StoreLike | None = None, # not used - storage_options: dict[str, Any] | None = None, # not used + storage_options: dict[str, Any] | None = None, zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, meta_array: Any | None = None, # not used @@ -548,10 +579,8 @@ async def open_group( warnings.warn("meta_array is not yet implemented", RuntimeWarning, stacklevel=2) if chunk_store is not None: warnings.warn("chunk_store is not yet implemented", RuntimeWarning, stacklevel=2) - if storage_options is not None: - warnings.warn("storage_options is not yet implemented", RuntimeWarning, stacklevel=2) - store_path = await make_store_path(store, mode=mode) + store_path = await make_store_path(store, mode=mode, storage_options=storage_options) if path is not None: store_path = store_path / path @@ -561,7 +590,7 @@ async def open_group( try: return await AsyncGroup.open(store_path, zarr_format=zarr_format) except (KeyError, FileNotFoundError): - return await AsyncGroup.create( + 
return await AsyncGroup.from_store( store_path, zarr_format=zarr_format or _default_zarr_version(), exists_ok=True, @@ -575,7 +604,7 @@ async def create( chunks: ChunkCoords | None = None, # TODO: v2 allowed chunks=True dtype: npt.DTypeLike | None = None, compressor: dict[str, JSON] | None = None, # TODO: default and type change - fill_value: Any = 0, # TODO: need type + fill_value: Any | None = 0, # TODO: need type order: MemoryOrder | None = None, # TODO: default change store: str | StoreLike | None = None, synchronizer: Any | None = None, @@ -603,6 +632,7 @@ async def create( ) = None, codecs: Iterable[Codec | dict[str, JSON]] | None = None, dimension_names: Iterable[str] | None = None, + storage_options: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncArray: """Create an array. @@ -674,6 +704,9 @@ async def create( to users. Use `numpy.empty(())` by default. .. versionadded:: 2.13 + storage_options : dict + If using an fsspec URL to create the store, these will be passed to + the backend implementation. Ignored otherwise. 
Returns ------- @@ -725,7 +758,7 @@ async def create( warnings.warn("meta_array is not yet implemented", RuntimeWarning, stacklevel=2) mode = kwargs.pop("mode", cast(AccessModeLiteral, "r" if read_only else "w")) - store_path = await make_store_path(store, mode=mode) + store_path = await make_store_path(store, mode=mode, storage_options=storage_options) if path is not None: store_path = store_path / path @@ -827,7 +860,7 @@ async def full_like(a: ArrayLike, **kwargs: Any) -> AsyncArray: """ like_kwargs = _like_args(a, kwargs) if isinstance(a, AsyncArray): - kwargs.setdefault("fill_value", a.metadata.fill_value) + like_kwargs.setdefault("fill_value", a.metadata.fill_value) return await full(**like_kwargs) @@ -875,6 +908,7 @@ async def open_array( zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, path: PathLike | None = None, + storage_options: dict[str, Any] | None = None, **kwargs: Any, # TODO: type kwargs as valid args to save ) -> AsyncArray: """Open an array using file-mode-like semantics. @@ -887,6 +921,9 @@ async def open_array( The zarr format to use when saving. path : string, optional Path in store to array. + storage_options : dict + If using an fsspec URL to create the store, these will be passed to + the backend implementation. Ignored otherwise. **kwargs Any keyword arguments to pass to the array constructor. @@ -896,7 +933,7 @@ async def open_array( The opened array. 
""" - store_path = await make_store_path(store) + store_path = await make_store_path(store, storage_options=storage_options) if path is not None: store_path = store_path / path diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index 93a33b8d3f..bc4a7bfafd 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any import zarr.api.asynchronous as async_api +from zarr._compat import _deprecate_positional_args from zarr.core.array import Array, AsyncArray from zarr.core.group import Group from zarr.core.sync import sync @@ -63,9 +64,10 @@ def load( return sync(async_api.load(store=store, zarr_version=zarr_version, path=path)) +@_deprecate_positional_args def open( - *, store: StoreLike | None = None, + *, mode: AccessModeLiteral | None = None, # type and value changed zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, @@ -107,6 +109,7 @@ def save( ) +@_deprecate_positional_args def save_array( store: StoreLike, arr: NDArrayLike, @@ -134,6 +137,7 @@ def save_group( zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, path: str | None = None, + storage_options: dict[str, Any] | None = None, **kwargs: NDArrayLike, ) -> None: return sync( @@ -143,6 +147,7 @@ def save_group( zarr_version=zarr_version, zarr_format=zarr_format, path=path, + storage_options=storage_options, **kwargs, ) ) @@ -157,9 +162,10 @@ def array(data: NDArrayLike, **kwargs: Any) -> Array: return Array(sync(async_api.array(data=data, **kwargs))) +@_deprecate_positional_args def group( - *, # Note: this is a change from v2 store: StoreLike | None = None, + *, # Note: this is a change from v2 overwrite: bool = False, chunk_store: StoreLike | None = None, # not used in async_api cache_attrs: bool | None = None, # default changed, not used in async_api @@ -188,9 +194,10 @@ def group( ) +@_deprecate_positional_args def open_group( - *, # 
Note: this is a change from v2 store: StoreLike | None = None, + *, # Note: this is a change from v2 mode: AccessModeLiteral | None = None, # not used in async api cache_attrs: bool | None = None, # default changed, not used in async api synchronizer: Any = None, # not used in async api diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 3a455b239f..b825ca4ca1 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -8,6 +8,8 @@ import numpy as np import numpy.typing as npt +from zarr._compat import _deprecate_positional_args +from zarr.abc.codec import Codec, CodecPipeline from zarr.abc.store import set_or_delete from zarr.codecs import BytesCodec from zarr.codecs._v2 import V2Compressor, V2Filters @@ -313,7 +315,7 @@ async def _create_v2( chunks=chunks, order=order, dimension_separator=dimension_separator, - fill_value=fill_value, + fill_value=0 if fill_value is None else fill_value, compressor=compressor, filters=filters, attributes=attributes, @@ -621,6 +623,7 @@ class Array: _async_array: AsyncArray @classmethod + @_deprecate_positional_args def create( cls, store: StoreLike, @@ -1016,6 +1019,7 @@ def __setitem__(self, selection: Selection, value: npt.ArrayLike) -> None: else: self.set_basic_selection(cast(BasicSelection, pure_selection), value, fields=fields) + @_deprecate_positional_args def get_basic_selection( self, selection: BasicSelection = Ellipsis, @@ -1139,6 +1143,7 @@ def get_basic_selection( ) ) + @_deprecate_positional_args def set_basic_selection( self, selection: BasicSelection, @@ -1234,6 +1239,7 @@ def set_basic_selection( indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid) sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) + @_deprecate_positional_args def get_orthogonal_selection( self, selection: OrthogonalSelection, @@ -1358,6 +1364,7 @@ def get_orthogonal_selection( ) ) + @_deprecate_positional_args def set_orthogonal_selection( self, selection: 
OrthogonalSelection, @@ -1468,6 +1475,7 @@ def set_orthogonal_selection( self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype) ) + @_deprecate_positional_args def get_mask_selection( self, mask: MaskSelection, @@ -1550,6 +1558,7 @@ def get_mask_selection( ) ) + @_deprecate_positional_args def set_mask_selection( self, mask: MaskSelection, @@ -1628,6 +1637,7 @@ def set_mask_selection( indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid) sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) + @_deprecate_positional_args def get_coordinate_selection( self, selection: CoordinateSelection, @@ -1717,6 +1727,7 @@ def get_coordinate_selection( out_array = np.array(out_array).reshape(indexer.sel_shape) return out_array + @_deprecate_positional_args def set_coordinate_selection( self, selection: CoordinateSelection, @@ -1806,6 +1817,7 @@ def set_coordinate_selection( sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) + @_deprecate_positional_args def get_block_selection( self, selection: BasicSelection, @@ -1904,6 +1916,7 @@ def get_block_selection( ) ) + @_deprecate_positional_args def set_block_selection( self, selection: BasicSelection, diff --git a/src/zarr/core/array_spec.py b/src/zarr/core/array_spec.py index e64a962bc3..1a251a0a4b 100644 --- a/src/zarr/core/array_spec.py +++ b/src/zarr/core/array_spec.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal -from zarr.core.common import parse_dtype, parse_fill_value, parse_order, parse_shapelike +from zarr.core.common import parse_fill_value, parse_order, parse_shapelike if TYPE_CHECKING: import numpy as np @@ -29,12 +29,11 @@ def __init__( prototype: BufferPrototype, ) -> None: shape_parsed = parse_shapelike(shape) - dtype_parsed = parse_dtype(dtype) fill_value_parsed = parse_fill_value(fill_value) order_parsed = parse_order(order) object.__setattr__(self, 
"shape", shape_parsed) - object.__setattr__(self, "dtype", dtype_parsed) + object.__setattr__(self, "dtype", dtype) object.__setattr__(self, "fill_value", fill_value_parsed) object.__setattr__(self, "order", order_parsed) object.__setattr__(self, "prototype", prototype) diff --git a/src/zarr/core/attributes.py b/src/zarr/core/attributes.py index 09677f7bdc..62ff5fc935 100644 --- a/src/zarr/core/attributes.py +++ b/src/zarr/core/attributes.py @@ -35,3 +35,19 @@ def __iter__(self) -> Iterator[str]: def __len__(self) -> int: return len(self._obj.metadata.attributes) + + def put(self, d: dict[str, JSON]) -> None: + """ + Overwrite all attributes with the values from `d`. + + Equivalent to the following pseudo-code, but performed atomically. + + .. code-block:: python + + >>> attrs = {"a": 1, "b": 2} + >>> attrs.clear() + >>> attrs.update({"a": 3, "c": 4}) + >>> attrs + {'a': 3, 'c': 4} + """ + self._obj = self._obj.update_attributes(d) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 906467005f..6847bd419f 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -19,14 +19,13 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterator -import numpy as np -import numpy.typing as npt ZARR_JSON = "zarr.json" ZARRAY_JSON = ".zarray" ZGROUP_JSON = ".zgroup" ZATTRS_JSON = ".zattrs" +ByteRangeRequest = tuple[int | None, int | None] BytesLike = bytes | bytearray | memoryview ShapeLike = tuple[int, ...] | int ChunkCoords = tuple[int, ...] 
@@ -154,11 +153,6 @@ def parse_shapelike(data: int | Iterable[int]) -> tuple[int, ...]: return data_tuple -def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]: - # todo: real validation - return np.dtype(data) - - def parse_fill_value(data: Any) -> Any: # todo: real validation return data diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 40815b96c8..7c56707a4f 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -10,8 +10,9 @@ import numpy.typing as npt from typing_extensions import deprecated +import zarr.api.asynchronous as async_api from zarr.abc.metadata import Metadata -from zarr.abc.store import set_or_delete +from zarr.abc.store import Store, set_or_delete from zarr.core.array import Array, AsyncArray from zarr.core.attributes import Attributes from zarr.core.buffer import default_buffer_prototype @@ -32,7 +33,7 @@ from zarr.store.common import ensure_no_existing_node if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Iterable, Iterator + from collections.abc import AsyncGenerator, Generator, Iterable, Iterator from typing import Any from zarr.abc.codec import Codec @@ -125,7 +126,7 @@ class AsyncGroup: store_path: StorePath @classmethod - async def create( + async def from_store( cls, store: StoreLike, *, @@ -311,6 +312,21 @@ def attrs(self) -> dict[str, Any]: def info(self) -> None: raise NotImplementedError + @property + def store(self) -> Store: + return self.store_path.store + + @property + def read_only(self) -> bool: + # Backwards compatibility for 2.x + return self.store_path.store.mode.readonly + + @property + def synchronizer(self) -> None: + # Backwards compatibility for 2.x + # Not implemented in 3.x yet. 
+ return None + async def create_group( self, name: str, @@ -319,7 +335,7 @@ async def create_group( attributes: dict[str, Any] | None = None, ) -> AsyncGroup: attributes = attributes or {} - return await type(self).create( + return await type(self).from_store( self.store_path / name, attributes=attributes, exists_ok=exists_ok, @@ -677,56 +693,70 @@ async def contains(self, member: str) -> bool: else: return True - # todo: decide if this method should be separate from `groups` - async def group_keys(self) -> AsyncGenerator[str, None]: - async for key, value in self.members(): + async def groups(self) -> AsyncGenerator[tuple[str, AsyncGroup], None]: + async for name, value in self.members(): if isinstance(value, AsyncGroup): - yield key + yield name, value - # todo: decide if this method should be separate from `group_keys` - async def groups(self) -> AsyncGenerator[AsyncGroup, None]: - async for _, value in self.members(): - if isinstance(value, AsyncGroup): - yield value + async def group_keys(self) -> AsyncGenerator[str, None]: + async for key, _ in self.groups(): + yield key - # todo: decide if this method should be separate from `arrays` - async def array_keys(self) -> AsyncGenerator[str, None]: + async def group_values(self) -> AsyncGenerator[AsyncGroup, None]: + async for _, group in self.groups(): + yield group + + async def arrays(self) -> AsyncGenerator[tuple[str, AsyncArray], None]: async for key, value in self.members(): if isinstance(value, AsyncArray): - yield key + yield key, value - # todo: decide if this method should be separate from `array_keys` - async def arrays(self) -> AsyncGenerator[AsyncArray, None]: - async for _, value in self.members(): - if isinstance(value, AsyncArray): - yield value + async def array_keys(self) -> AsyncGenerator[str, None]: + async for key, _ in self.arrays(): + yield key + + async def array_values(self) -> AsyncGenerator[AsyncArray, None]: + async for _, array in self.arrays(): + yield array async def tree(self, 
expand: bool = False, level: int | None = None) -> Any: raise NotImplementedError - async def empty(self, **kwargs: Any) -> AsyncArray: - raise NotImplementedError + async def empty(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> AsyncArray: + return await async_api.empty(shape=shape, store=self.store_path, path=name, **kwargs) - async def zeros(self, **kwargs: Any) -> AsyncArray: - raise NotImplementedError + async def zeros(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> AsyncArray: + return await async_api.zeros(shape=shape, store=self.store_path, path=name, **kwargs) - async def ones(self, **kwargs: Any) -> AsyncArray: - raise NotImplementedError + async def ones(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> AsyncArray: + return await async_api.ones(shape=shape, store=self.store_path, path=name, **kwargs) - async def full(self, **kwargs: Any) -> AsyncArray: - raise NotImplementedError + async def full( + self, *, name: str, shape: ChunkCoords, fill_value: Any | None, **kwargs: Any + ) -> AsyncArray: + return await async_api.full( + shape=shape, fill_value=fill_value, store=self.store_path, path=name, **kwargs + ) - async def empty_like(self, prototype: AsyncArray, **kwargs: Any) -> AsyncArray: - raise NotImplementedError + async def empty_like( + self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any + ) -> AsyncArray: + return await async_api.empty_like(a=prototype, store=self.store_path, path=name, **kwargs) - async def zeros_like(self, prototype: AsyncArray, **kwargs: Any) -> AsyncArray: - raise NotImplementedError + async def zeros_like( + self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any + ) -> AsyncArray: + return await async_api.zeros_like(a=prototype, store=self.store_path, path=name, **kwargs) - async def ones_like(self, prototype: AsyncArray, **kwargs: Any) -> AsyncArray: - raise NotImplementedError + async def ones_like( + self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any + ) 
-> AsyncArray: + return await async_api.ones_like(a=prototype, store=self.store_path, path=name, **kwargs) - async def full_like(self, prototype: AsyncArray, **kwargs: Any) -> AsyncArray: - raise NotImplementedError + async def full_like( + self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any + ) -> AsyncArray: + return await async_api.full_like(a=prototype, store=self.store_path, path=name, **kwargs) async def move(self, source: str, dest: str) -> None: raise NotImplementedError @@ -737,7 +767,7 @@ class Group(SyncMixin): _async_group: AsyncGroup @classmethod - def create( + def from_store( cls, store: StoreLike, *, @@ -747,7 +777,7 @@ def create( ) -> Group: attributes = attributes or {} obj = sync( - AsyncGroup.create( + AsyncGroup.from_store( store, attributes=attributes, exists_ok=exists_ok, @@ -828,6 +858,22 @@ def attrs(self) -> Attributes: def info(self) -> None: raise NotImplementedError + @property + def store(self) -> Store: + # Backwards compatibility for 2.x + return self._async_group.store + + @property + def read_only(self) -> bool: + # Backwards compatibility for 2.x + return self._async_group.read_only + + @property + def synchronizer(self) -> None: + # Backwards compatibility for 2.x + # Not implemented in 3.x yet. 
+ return self._async_group.synchronizer + def update_attributes(self, new_attributes: dict[str, Any]) -> Group: self._sync(self._async_group.update_attributes(new_attributes)) return self @@ -848,18 +894,29 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group], def __contains__(self, member: str) -> bool: return self._sync(self._async_group.contains(member)) - def group_keys(self) -> tuple[str, ...]: - return tuple(self._sync_iter(self._async_group.group_keys())) + def groups(self) -> Generator[tuple[str, Group], None]: + for name, async_group in self._sync_iter(self._async_group.groups()): + yield name, Group(async_group) + + def group_keys(self) -> Generator[str, None]: + for name, _ in self.groups(): + yield name - def groups(self) -> tuple[Group, ...]: - # TODO: in v2 this was a generator that return key: Group - return tuple(Group(obj) for obj in self._sync_iter(self._async_group.groups())) + def group_values(self) -> Generator[Group, None]: + for _, group in self.groups(): + yield group - def array_keys(self) -> tuple[str, ...]: - return tuple(self._sync_iter(self._async_group.array_keys())) + def arrays(self) -> Generator[tuple[str, Array], None]: + for name, async_array in self._sync_iter(self._async_group.arrays()): + yield name, Array(async_array) - def arrays(self) -> tuple[Array, ...]: - return tuple(Array(obj) for obj in self._sync_iter(self._async_group.arrays())) + def array_keys(self) -> Generator[str, None]: + for name, _ in self.arrays(): + yield name + + def array_values(self) -> Generator[Array, None]: + for _, array in self.arrays(): + yield array def tree(self, expand: bool = False, level: int | None = None) -> Any: return self._sync(self._async_group.tree(expand=expand, level=level)) @@ -887,6 +944,10 @@ def require_groups(self, *names: str) -> tuple[Group, ...]: """Convenience method to require multiple groups in a single call.""" return tuple(map(Group, self._sync(self._async_group.require_groups(*names)))) + 
def create(self, *args: Any, **kwargs: Any) -> Array: + # Backwards compatibility for 2.x + return self.create_array(*args, **kwargs) + def create_array( self, name: str, @@ -1058,29 +1119,43 @@ def require_array(self, name: str, **kwargs: Any) -> Array: """ return Array(self._sync(self._async_group.require_array(name, **kwargs))) - def empty(self, **kwargs: Any) -> Array: - return Array(self._sync(self._async_group.empty(**kwargs))) + def empty(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> Array: + return Array(self._sync(self._async_group.empty(name=name, shape=shape, **kwargs))) - def zeros(self, **kwargs: Any) -> Array: - return Array(self._sync(self._async_group.zeros(**kwargs))) + def zeros(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> Array: + return Array(self._sync(self._async_group.zeros(name=name, shape=shape, **kwargs))) - def ones(self, **kwargs: Any) -> Array: - return Array(self._sync(self._async_group.ones(**kwargs))) + def ones(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> Array: + return Array(self._sync(self._async_group.ones(name=name, shape=shape, **kwargs))) - def full(self, **kwargs: Any) -> Array: - return Array(self._sync(self._async_group.full(**kwargs))) + def full( + self, *, name: str, shape: ChunkCoords, fill_value: Any | None, **kwargs: Any + ) -> Array: + return Array( + self._sync( + self._async_group.full(name=name, shape=shape, fill_value=fill_value, **kwargs) + ) + ) - def empty_like(self, prototype: AsyncArray, **kwargs: Any) -> Array: - return Array(self._sync(self._async_group.empty_like(prototype, **kwargs))) + def empty_like(self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any) -> Array: + return Array( + self._sync(self._async_group.empty_like(name=name, prototype=prototype, **kwargs)) + ) - def zeros_like(self, prototype: AsyncArray, **kwargs: Any) -> Array: - return Array(self._sync(self._async_group.zeros_like(prototype, **kwargs))) + def zeros_like(self, *, name: str, 
prototype: async_api.ArrayLike, **kwargs: Any) -> Array: + return Array( + self._sync(self._async_group.zeros_like(name=name, prototype=prototype, **kwargs)) + ) - def ones_like(self, prototype: AsyncArray, **kwargs: Any) -> Array: - return Array(self._sync(self._async_group.ones_like(prototype, **kwargs))) + def ones_like(self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any) -> Array: + return Array( + self._sync(self._async_group.ones_like(name=name, prototype=prototype, **kwargs)) + ) - def full_like(self, prototype: AsyncArray, **kwargs: Any) -> Array: - return Array(self._sync(self._async_group.full_like(prototype, **kwargs))) + def full_like(self, *, name: str, prototype: async_api.ArrayLike, **kwargs: Any) -> Array: + return Array( + self._sync(self._async_group.full_like(name=name, prototype=prototype, **kwargs)) + ) def move(self, source: str, dest: str) -> None: return self._sync(self._async_group.move(source, dest)) diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index af7821bea7..34bdbb537f 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Iterable +from enum import Enum from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -21,7 +22,7 @@ from zarr.core.array_spec import ArraySpec from zarr.core.chunk_grids import RegularChunkGrid from zarr.core.chunk_key_encodings import parse_separator -from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_dtype, parse_shapelike +from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_shapelike from zarr.core.config import config, parse_indexing_order from zarr.core.metadata.common import ArrayMetadata, parse_attributes @@ -100,9 +101,24 @@ def _json_convert( else: return o.descr if np.isscalar(o): - # convert numpy scalar to python type, and pass - # python types through - return getattr(o, "item", lambda: o)() + out: Any + if hasattr(o, "dtype") and o.dtype.kind 
def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
    """Coerce *data* (a string, type object, or dtype) into a NumPy dtype.

    # TODO: real validation — currently any np.dtype-constructible input is accepted.
    """
    # Delegate entirely to NumPy's dtype constructor; it raises TypeError for
    # inputs it cannot interpret, which callers treat as invalid metadata.
    return np.dtype(data)
| None: raise TypeError(msg) +class V3JsonEncoder(json.JSONEncoder): + def __init__(self, *args: Any, **kwargs: Any): + self.indent = kwargs.pop("indent", config.get("json_indent")) + super().__init__(*args, **kwargs) + + def default(self, o: object) -> Any: + if isinstance(o, np.dtype): + return str(o) + if np.isscalar(o): + out: Any + if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"): + # https://github.com/zarr-developers/zarr-python/issues/2119 + # `.item()` on a datetime type might or might not return an + # integer, depending on the value. + # Explicitly cast to an int first, and then grab .item() + out = o.view("i8").item() + else: + # convert numpy scalar to python type, and pass + # python types through + out = getattr(o, "item", lambda: o)() + if isinstance(out, complex): + # python complex types are not JSON serializable, so we use the + # serialization defined in the zarr v3 spec + return [out.real, out.imag] + elif np.isnan(out): + return "NaN" + elif np.isinf(out): + return "Infinity" if out > 0 else "-Infinity" + return out + elif isinstance(o, Enum): + return o.name + # this serializes numcodecs compressors + # todo: implement to_dict for codecs + elif isinstance(o, numcodecs.abc.Codec): + config: dict[str, Any] = o.get_config() + return config + else: + return super().default(o) + + +def _replace_special_floats(obj: object) -> Any: + """Helper function to replace NaN/Inf/-Inf values with special strings + + Note: this cannot be done in the V3JsonEncoder because Python's `json.dumps` optimistically + converts NaN/Inf values to special types outside of the encoding step. 
+ """ + if isinstance(obj, float): + if np.isnan(obj): + return "NaN" + elif np.isinf(obj): + return "Infinity" if obj > 0 else "-Infinity" + elif isinstance(obj, dict): + # Recursively replace in dictionaries + return {k: _replace_special_floats(v) for k, v in obj.items()} + elif isinstance(obj, list): + # Recursively replace in lists + return [_replace_special_floats(item) for item in obj] + return obj + + @dataclass(frozen=True, kw_only=True) class ArrayV3Metadata(ArrayMetadata): shape: ChunkCoords @@ -169,41 +230,8 @@ def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str: return self.chunk_key_encoding.encode_chunk_key(chunk_coords) def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: - def _json_convert(o: object) -> Any: - if isinstance(o, np.dtype): - return str(o) - if np.isscalar(o): - out: Any - if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"): - # https://github.com/zarr-developers/zarr-python/issues/2119 - # `.item()` on a datetime type might or might not return an - # integer, depending on the value. 
- # Explicitly cast to an int first, and then grab .item() - out = o.view("i8").item() - else: - # convert numpy scalar to python type, and pass - # python types through - out = getattr(o, "item", lambda: o)() - if isinstance(out, complex): - # python complex types are not JSON serializable, so we use the - # serialization defined in the zarr v3 spec - return [out.real, out.imag] - return out - if isinstance(o, Enum): - return o.name - # this serializes numcodecs compressors - # todo: implement to_dict for codecs - elif isinstance(o, numcodecs.abc.Codec): - config: dict[str, Any] = o.get_config() - return config - raise TypeError - - json_indent = config.get("json_indent") - return { - ZARR_JSON: prototype.buffer.from_bytes( - json.dumps(self.to_dict(), default=_json_convert, indent=json_indent).encode() - ) - } + d = _replace_special_floats(self.to_dict()) + return {ZARR_JSON: prototype.buffer.from_bytes(json.dumps(d, cls=V3JsonEncoder).encode())} @classmethod def from_dict(cls, data: dict[str, JSON]) -> Self: @@ -215,6 +243,9 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: # check that the node_type attribute is correct _ = parse_node_type_array(_data.pop("node_type")) + # check that the data_type attribute is valid + _ = DataType(_data["data_type"]) + # dimension_names key is optional, normalize missing to `None` _data["dimension_names"] = _data.pop("dimension_names", None) # attributes key is optional, normalize missing to `None` @@ -260,23 +291,38 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self: @overload -def parse_fill_value(fill_value: object, dtype: BOOL_DTYPE) -> BOOL: ... +def parse_fill_value( + fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None, + dtype: BOOL_DTYPE, +) -> BOOL: ... @overload -def parse_fill_value(fill_value: object, dtype: INTEGER_DTYPE) -> INTEGER: ... 
+def parse_fill_value( + fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None, + dtype: INTEGER_DTYPE, +) -> INTEGER: ... @overload -def parse_fill_value(fill_value: object, dtype: FLOAT_DTYPE) -> FLOAT: ... +def parse_fill_value( + fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None, + dtype: FLOAT_DTYPE, +) -> FLOAT: ... @overload -def parse_fill_value(fill_value: object, dtype: COMPLEX_DTYPE) -> COMPLEX: ... +def parse_fill_value( + fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None, + dtype: COMPLEX_DTYPE, +) -> COMPLEX: ... @overload -def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any: +def parse_fill_value( + fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None, + dtype: np.dtype[Any], +) -> Any: # This dtype[Any] is unfortunately necessary right now. # See https://github.com/zarr-developers/zarr-python/issues/2131#issuecomment-2318010899 # for more details, but `dtype` here (which comes from `parse_dtype`) @@ -288,7 +334,7 @@ def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any: def parse_fill_value( - fill_value: object, + fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None, dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | np.dtype[Any], ) -> BOOL | INTEGER | FLOAT | COMPLEX | Any: """ @@ -322,13 +368,40 @@ def parse_fill_value( else: msg = ( f"Got an invalid fill value for complex data type {dtype}." - f"Expected a sequence with 2 elements, but {fill_value} has " + f"Expected a sequence with 2 elements, but {fill_value!r} has " f"length {len(fill_value)}." ) raise ValueError(msg) - msg = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype}." + msg = f"Cannot parse non-string sequence {fill_value!r} as a scalar with type {dtype}." 
raise TypeError(msg) - return dtype.type(fill_value) # type: ignore[arg-type] + + # Cast the fill_value to the given dtype + try: + # This warning filter can be removed after Zarr supports numpy>=2.0 + # The warning is saying that the future behavior of out of bounds casting will be to raise + # an OverflowError. In the meantime, we allow overflow and catch cases where + # fill_value != casted_value below. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + casted_value = np.dtype(dtype).type(fill_value) + except (ValueError, OverflowError, TypeError) as e: + raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") from e + # Check if the value is still representable by the dtype + if fill_value == "NaN" and np.isnan(casted_value): + pass + elif fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value): + pass + elif dtype.kind == "f": + # float comparison is not exact, especially when dtype int: @@ -360,8 +436,11 @@ def byte_count(self) -> int: DataType.uint16: 2, DataType.uint32: 4, DataType.uint64: 8, + DataType.float16: 2, DataType.float32: 4, DataType.float64: 8, + DataType.complex64: 8, + DataType.complex128: 16, } return data_type_byte_counts[self] @@ -381,8 +460,11 @@ def to_numpy_shortname(self) -> str: DataType.uint16: "u2", DataType.uint32: "u4", DataType.uint64: "u8", + DataType.float16: "f2", DataType.float32: "f4", DataType.float64: "f8", + DataType.complex64: "c8", + DataType.complex128: "c16", } return data_type_to_numpy[self] @@ -399,7 +481,24 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType: " np.dtype[Any]: + try: + dtype = np.dtype(data) + except (ValueError, TypeError) as e: + raise ValueError(f"Invalid V3 data_type: {data}") from e + # check that this is a valid v3 data_type + try: + _ = DataType.from_dtype(dtype) + except KeyError as e: + raise ValueError(f"Invalid V3 data_type: {dtype}") from e + + return dtype diff --git a/src/zarr/core/sync.py 
b/src/zarr/core/sync.py index ff7f9a43af..db3dce79b2 100644 --- a/src/zarr/core/sync.py +++ b/src/zarr/core/sync.py @@ -113,6 +113,23 @@ def _get_loop() -> asyncio.AbstractEventLoop: return loop[0] +async def _collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]: + """ + Collect an entire async iterator into a tuple + """ + result = [] + async for x in data: + result.append(x) + return tuple(result) + + +def collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]: + """ + Synchronously collect an entire async iterator into a tuple. + """ + return sync(_collect_aiterator(data)) + + class SyncMixin: def _sync(self, coroutine: Coroutine[Any, Any, T]) -> T: # TODO: refactor this to to take *args and **kwargs and pass those to the method diff --git a/src/zarr/store/common.py b/src/zarr/store/common.py index 8028c9af3d..196479dd67 100644 --- a/src/zarr/store/common.py +++ b/src/zarr/store/common.py @@ -11,6 +11,8 @@ from zarr.store.local import LocalStore from zarr.store.memory import MemoryStore +# from zarr.store.remote import RemoteStore + if TYPE_CHECKING: from zarr.core.buffer import BufferPrototype from zarr.core.common import AccessModeLiteral @@ -75,30 +77,69 @@ def __eq__(self, other: Any) -> bool: async def make_store_path( - store_like: StoreLike | None, *, mode: AccessModeLiteral | None = None + store_like: StoreLike | None, + *, + mode: AccessModeLiteral | None = None, + storage_options: dict[str, Any] | None = None, ) -> StorePath: + from zarr.store.remote import RemoteStore # circular import + + used_storage_options = False + if isinstance(store_like, StorePath): if mode is not None: assert AccessMode.from_literal(mode) == store_like.store.mode - return store_like + result = store_like elif isinstance(store_like, Store): if mode is not None: assert AccessMode.from_literal(mode) == store_like.mode await store_like._ensure_open() - return StorePath(store_like) + result = StorePath(store_like) elif store_like is None: if mode is None: mode = "w" # 
exception to the default mode = 'r' - return StorePath(await MemoryStore.open(mode=mode)) + result = StorePath(await MemoryStore.open(mode=mode)) elif isinstance(store_like, Path): - return StorePath(await LocalStore.open(root=store_like, mode=mode or "r")) + result = StorePath(await LocalStore.open(root=store_like, mode=mode or "r")) elif isinstance(store_like, str): - return StorePath(await LocalStore.open(root=Path(store_like), mode=mode or "r")) + storage_options = storage_options or {} + + if _is_fsspec_uri(store_like): + used_storage_options = True + result = StorePath( + RemoteStore.from_url(store_like, storage_options=storage_options, mode=mode or "r") + ) + else: + result = StorePath(await LocalStore.open(root=Path(store_like), mode=mode or "r")) elif isinstance(store_like, dict): # We deliberate only consider dict[str, Buffer] here, and not arbitrary mutable mappings. # By only allowing dictionaries, which are in-memory, we know that MemoryStore appropriate. - return StorePath(await MemoryStore.open(store_dict=store_like, mode=mode)) - raise TypeError + result = StorePath(await MemoryStore.open(store_dict=store_like, mode=mode)) + else: + msg = f"Unsupported type for store_like: '{type(store_like).__name__}'" # type: ignore[unreachable] + raise TypeError(msg) + + if storage_options and not used_storage_options: + msg = "'storage_options' was provided but unused. 'storage_options' is only used for fsspec filesystem stores." + raise TypeError(msg) + + return result + + +def _is_fsspec_uri(uri: str) -> bool: + """ + Check if a URI looks like a non-local fsspec URI. 
+ + Examples + -------- + >>> _is_fsspec_uri("s3://bucket") + True + >>> _is_fsspec_uri("my-directory") + False + >>> _is_fsspec_uri("local://my-directory") + False + """ + return "://" in uri or "::" in uri and "local://" not in uri async def ensure_no_existing_node(store_path: StorePath, zarr_format: ZarrFormat) -> None: diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index 5fd48c2db0..c78837586f 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -191,7 +191,9 @@ async def list(self) -> AsyncGenerator[str, None]: yield str(p).replace(to_strip, "") async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: - """Retrieve all keys in the store with a given prefix. + """ + Retrieve all keys in the store that begin with a given prefix. Keys are returned with the + common leading prefix removed. Parameters ---------- @@ -201,14 +203,10 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: ------- AsyncGenerator[str, None] """ + to_strip = os.path.join(str(self.root / prefix)) for p in (self.root / prefix).rglob("*"): if p.is_file(): - yield str(p) - - to_strip = str(self.root) + "/" - for p in (self.root / prefix).rglob("*"): - if p.is_file(): - yield str(p).replace(to_strip, "") + yield str(p.relative_to(to_strip)) async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: """ diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index 13e289f374..7baa6aee26 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import AsyncGenerator, MutableMapping -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from zarr.abc.store import Store from zarr.core.buffer import Buffer, gpu @@ -55,12 +55,6 @@ def __eq__(self, other: object) -> bool: and self.mode == other.mode ) - def __setstate__(self, state: Any) -> None: - raise NotImplementedError(f"{type(self)} cannot be pickled") - - def 
__getstate__(self) -> None: - raise NotImplementedError(f"{type(self)} cannot be pickled") - async def get( self, key: str, @@ -124,9 +118,21 @@ async def list(self) -> AsyncGenerator[str, None]: async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: for key in self._store_dict: if key.startswith(prefix): - yield key + yield key.removeprefix(prefix) async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: + """ + Retrieve all keys in the store that begin with a given prefix. Keys are returned with the + common leading prefix removed. + + Parameters + ---------- + prefix : str + + Returns + ------- + AsyncGenerator[str, None] + """ if prefix.endswith("/"): prefix = prefix[:-1] diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index e3e2ba3447..ecb46a31d3 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -11,12 +11,18 @@ from collections.abc import AsyncGenerator from fsspec.asyn import AsyncFileSystem - from upath import UPath from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.common import AccessModeLiteral, BytesLike +ALLOWED_EXCEPTIONS: tuple[type[Exception], ...] = ( + FileNotFoundError, + IsADirectoryError, + NotADirectoryError, +) + + class RemoteStore(Store): # based on FSSpec supports_writes: bool = True @@ -24,21 +30,15 @@ class RemoteStore(Store): supports_partial_writes: bool = False supports_listing: bool = True - _fs: AsyncFileSystem - _url: str - path: str + fs: AsyncFileSystem allowed_exceptions: tuple[type[Exception], ...] def __init__( self, - url: UPath | str, + fs: AsyncFileSystem, mode: AccessModeLiteral = "r", - allowed_exceptions: tuple[type[Exception], ...] = ( - FileNotFoundError, - IsADirectoryError, - NotADirectoryError, - ), - **storage_options: Any, + path: str = "/", + allowed_exceptions: tuple[type[Exception], ...] = ALLOWED_EXCEPTIONS, ): """ Parameters @@ -51,54 +51,58 @@ def __init__( this must not be used. 
""" super().__init__(mode=mode) - self._storage_options = storage_options - if isinstance(url, str): - self._url = url.rstrip("/") - self._fs, _path = fsspec.url_to_fs(url, **storage_options) - self.path = _path.rstrip("/") - elif hasattr(url, "protocol") and hasattr(url, "fs"): - # is UPath-like - but without importing - if storage_options: - raise ValueError( - "If constructed with a UPath object, no additional " - "storage_options are allowed" - ) - # n.b. UPath returns the url and path attributes with a trailing /, at least for s3 - # that trailing / must be removed to compose with the store interface - self._url = str(url).rstrip("/") - self.path = url.path.rstrip("/") - self._fs = url.fs - else: - raise ValueError(f"URL not understood, {url}") + self.fs = fs + self.path = path self.allowed_exceptions = allowed_exceptions - # test instantiate file system - if not self._fs.async_impl: - raise TypeError("FileSystem needs to support async operations") + + if not self.fs.async_impl: + raise TypeError("Filesystem needs to support async operations.") + + @classmethod + def from_upath( + cls, + upath: Any, + mode: AccessModeLiteral = "r", + allowed_exceptions: tuple[type[Exception], ...] = ALLOWED_EXCEPTIONS, + ) -> RemoteStore: + return cls( + fs=upath.fs, + path=upath.path.rstrip("/"), + mode=mode, + allowed_exceptions=allowed_exceptions, + ) + + @classmethod + def from_url( + cls, + url: str, + storage_options: dict[str, Any] | None = None, + mode: AccessModeLiteral = "r", + allowed_exceptions: tuple[type[Exception], ...] 
= ALLOWED_EXCEPTIONS, + ) -> RemoteStore: + fs, path = fsspec.url_to_fs(url, **storage_options) + return cls(fs=fs, path=path, mode=mode, allowed_exceptions=allowed_exceptions) async def clear(self) -> None: try: - for subpath in await self._fs._find(self.path, withdirs=True): + for subpath in await self.fs._find(self.path, withdirs=True): if subpath != self.path: - await self._fs._rm(subpath, recursive=True) + await self.fs._rm(subpath, recursive=True) except FileNotFoundError: pass async def empty(self) -> bool: - return not await self._fs._find(self.path, withdirs=True) - - def __str__(self) -> str: - return f"{self._url}" + return not await self.fs._find(self.path, withdirs=True) def __repr__(self) -> str: - return f"" + return f"" def __eq__(self, other: object) -> bool: return ( isinstance(other, type(self)) and self.path == other.path and self.mode == other.mode - and self._url == other._url - # and self._storage_options == other._storage_options # FIXME: this isn't working for some reason + and self.fs == other.fs ) async def get( @@ -123,9 +127,9 @@ async def get( end = None value = prototype.buffer.from_bytes( await ( - self._fs._cat_file(path, start=byte_range[0], end=end) + self.fs._cat_file(path, start=byte_range[0], end=end) if byte_range - else self._fs._cat_file(path) + else self.fs._cat_file(path) ) ) @@ -152,13 +156,13 @@ async def set( # write data if byte_range: raise NotImplementedError - await self._fs._pipe_file(path, value.to_bytes()) + await self.fs._pipe_file(path, value.to_bytes()) async def delete(self, key: str) -> None: self._check_writable() path = _dereference_path(self.path, key) try: - await self._fs._rm(path) + await self.fs._rm(path) except FileNotFoundError: pass except self.allowed_exceptions: @@ -166,7 +170,7 @@ async def delete(self, key: str) -> None: async def exists(self, key: str) -> bool: path = _dereference_path(self.path, key) - exists: bool = await self._fs._exists(path) + exists: bool = await self.fs._exists(path) 
return exists async def get_partial_values( @@ -189,7 +193,7 @@ async def get_partial_values( else: return [] # TODO: expectations for exceptions or missing keys? - res = await self._fs._cat_ranges(list(paths), starts, stops, on_error="return") + res = await self.fs._cat_ranges(list(paths), starts, stops, on_error="return") # the following is an s3-specific condition we probably don't want to leak res = [b"" if (isinstance(r, OSError) and "not satisfiable" in str(r)) else r for r in res] for r in res: @@ -202,19 +206,33 @@ async def set_partial_values(self, key_start_values: list[tuple[str, int, BytesL raise NotImplementedError async def list(self) -> AsyncGenerator[str, None]: - allfiles = await self._fs._find(self.path, detail=False, withdirs=False) + allfiles = await self.fs._find(self.path, detail=False, withdirs=False) for onefile in (a.replace(self.path + "/", "") for a in allfiles): yield onefile async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: prefix = f"{self.path}/{prefix.rstrip('/')}" try: - allfiles = await self._fs._ls(prefix, detail=False) + allfiles = await self.fs._ls(prefix, detail=False) except FileNotFoundError: return for onefile in (a.replace(prefix + "/", "") for a in allfiles): yield onefile.removeprefix(self.path).removeprefix("/") async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: - for onefile in await self._fs._ls(prefix, detail=False): - yield onefile + """ + Retrieve all keys in the store that begin with a given prefix. Keys are returned with the + common leading prefix removed. 
+ + Parameters + ---------- + prefix : str + + Returns + ------- + AsyncGenerator[str, None] + """ + + find_str = "/".join([self.path, prefix]) + for onefile in await self.fs._find(find_str, detail=False, maxdepth=None, withdirs=False): + yield onefile.removeprefix(find_str) diff --git a/src/zarr/store/zip.py b/src/zarr/store/zip.py index ea31ad934a..2e4927aced 100644 --- a/src/zarr/store/zip.py +++ b/src/zarr/store/zip.py @@ -209,9 +209,21 @@ async def list(self) -> AsyncGenerator[str, None]: yield key async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: + """ + Retrieve all keys in the store that begin with a given prefix. Keys are returned with the + common leading prefix removed. + + Parameters + ---------- + prefix : str + + Returns + ------- + AsyncGenerator[str, None] + """ async for key in self.list(): if key.startswith(prefix): - yield key + yield key.removeprefix(prefix) async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: if prefix.endswith("/"): diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index a08b6960db..ebd4b85c90 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -3,9 +3,9 @@ import pytest -import zarr.api.asynchronous from zarr.abc.store import AccessMode, Store from zarr.core.buffer import Buffer, default_buffer_prototype +from zarr.core.sync import _collect_aiterator from zarr.store._utils import _normalize_interval_index from zarr.testing.utils import assert_bytes_equal @@ -58,7 +58,7 @@ def test_store_eq(self, store: S, store_kwargs: dict[str, Any]) -> None: store2 = self.store_cls(**store_kwargs) assert store == store2 - def test_serizalizable_store(self, store: S) -> None: + def test_serializable_store(self, store: S) -> None: foo = pickle.dumps(store) assert pickle.loads(foo) == store @@ -123,6 +123,18 @@ async def test_set(self, store: S, key: str, data: bytes) -> None: observed = self.get(store, key) assert_bytes_equal(observed, data_buf) + async def 
test_set_many(self, store: S) -> None: + """ + Test that a dict of key : value pairs can be inserted into the store via the + `_set_many` method. + """ + keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"] + data_buf = [self.buffer_cls.from_bytes(k.encode()) for k in keys] + store_dict = dict(zip(keys, data_buf, strict=True)) + await store._set_many(store_dict.items()) + for k, v in store_dict.items(): + assert self.get(store, k).to_bytes() == v.to_bytes() + @pytest.mark.parametrize( "key_ranges", ( @@ -185,76 +197,57 @@ async def test_clear(self, store: S) -> None: assert await store.empty() async def test_list(self, store: S) -> None: - assert [k async for k in store.list()] == [] - await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar")) - keys = [k async for k in store.list()] - assert keys == ["foo/zarr.json"], keys - - expected = ["foo/zarr.json"] - for i in range(10): - key = f"foo/c/{i}" - expected.append(key) - await store.set( - f"foo/c/{i}", self.buffer_cls.from_bytes(i.to_bytes(length=3, byteorder="little")) - ) + assert await _collect_aiterator(store.list()) == () + prefix = "foo" + data = self.buffer_cls.from_bytes(b"") + store_dict = { + prefix + "/zarr.json": data, + **{prefix + f"/c/{idx}": data for idx in range(10)}, + } + await store._set_many(store_dict.items()) + expected_sorted = sorted(store_dict.keys()) + observed = await _collect_aiterator(store.list()) + observed_sorted = sorted(observed) + assert observed_sorted == expected_sorted - @pytest.mark.xfail async def test_list_prefix(self, store: S) -> None: - # TODO: we currently don't use list_prefix anywhere - raise NotImplementedError + """ + Test that the `list_prefix` method works as intended. Given a prefix, it should return + all the keys in storage that start with this prefix. Keys should be returned with the shared + prefix removed. 
+ """ + prefixes = ("", "a/", "a/b/", "a/b/c/") + data = self.buffer_cls.from_bytes(b"") + fname = "zarr.json" + store_dict = {p + fname: data for p in prefixes} + + await store._set_many(store_dict.items()) + + for prefix in prefixes: + observed = tuple(sorted(await _collect_aiterator(store.list_prefix(prefix)))) + expected: tuple[str, ...] = () + for key in store_dict.keys(): + if key.startswith(prefix): + expected += (key.removeprefix(prefix),) + expected = tuple(sorted(expected)) + assert observed == expected async def test_list_dir(self, store: S) -> None: - out = [k async for k in store.list_dir("")] - assert out == [] - assert [k async for k in store.list_dir("foo")] == [] - await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar")) - await store.set("group-0/zarr.json", self.buffer_cls.from_bytes(b"\x01")) # group - await store.set("group-0/group-1/zarr.json", self.buffer_cls.from_bytes(b"\x01")) # group - await store.set("group-0/group-1/a1/zarr.json", self.buffer_cls.from_bytes(b"\x01")) - await store.set("group-0/group-1/a2/zarr.json", self.buffer_cls.from_bytes(b"\x01")) - await store.set("group-0/group-1/a3/zarr.json", self.buffer_cls.from_bytes(b"\x01")) - - keys_expected = ["foo", "group-0"] - keys_observed = [k async for k in store.list_dir("")] - assert set(keys_observed) == set(keys_expected) - - keys_expected = ["zarr.json"] - keys_observed = [k async for k in store.list_dir("foo")] - - assert len(keys_observed) == len(keys_expected), keys_observed - assert set(keys_observed) == set(keys_expected), keys_observed - - keys_observed = [k async for k in store.list_dir("foo/")] - assert len(keys_expected) == len(keys_observed), keys_observed - assert set(keys_observed) == set(keys_expected), keys_observed - - keys_observed = [k async for k in store.list_dir("group-0")] - keys_expected = ["zarr.json", "group-1"] - - assert len(keys_observed) == len(keys_expected), keys_observed - assert set(keys_observed) == set(keys_expected), keys_observed 
- - keys_observed = [k async for k in store.list_dir("group-0/")] - assert len(keys_expected) == len(keys_observed), keys_observed - assert set(keys_observed) == set(keys_expected), keys_observed + root = "foo" + store_dict = { + root + "/zarr.json": self.buffer_cls.from_bytes(b"bar"), + root + "/c/1": self.buffer_cls.from_bytes(b"\x01"), + } - keys_observed = [k async for k in store.list_dir("group-0/group-1")] - keys_expected = ["zarr.json", "a1", "a2", "a3"] + assert await _collect_aiterator(store.list_dir("")) == () + assert await _collect_aiterator(store.list_dir(root)) == () - assert len(keys_observed) == len(keys_expected), keys_observed - assert set(keys_observed) == set(keys_expected), keys_observed + await store._set_many(store_dict.items()) - keys_observed = [k async for k in store.list_dir("group-0/group-1")] - assert len(keys_expected) == len(keys_observed), keys_observed - assert set(keys_observed) == set(keys_expected), keys_observed + keys_observed = await _collect_aiterator(store.list_dir(root)) + keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict.keys()} - async def test_set_get(self, store_kwargs: dict[str, Any]) -> None: - kwargs = {**store_kwargs, **{"mode": "w"}} - store = self.store_cls(**kwargs) - await zarr.api.asynchronous.open_array(store=store, path="a", mode="w", shape=(4,)) - keys = [x async for x in store.list()] - assert keys == ["a/zarr.json"] + assert sorted(keys_observed) == sorted(keys_expected) - # no errors - await zarr.api.asynchronous.open_array(store=store, path="a", mode="r") - await zarr.api.asynchronous.open_array(store=store, path="a", mode="a") + keys_observed = await _collect_aiterator(store.list_dir(root + "/")) + assert sorted(keys_expected) == sorted(keys_observed) diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index 77e51f57e0..77c1e0da51 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -36,11 +36,14 @@ paths = st.just("/") 
| keys np_arrays = npst.arrays( # TODO: re-enable timedeltas once they are supported - dtype=npst.scalar_dtypes().filter(lambda x: x.kind != "m"), + dtype=npst.scalar_dtypes().filter( + lambda x: (x.kind not in ["m", "M"]) and (x.byteorder not in [">"]) + ), shape=npst.array_shapes(max_dims=4), ) stores = st.builds(MemoryStore, st.just({}), mode=st.just("w")) compressors = st.sampled_from([None, "default"]) +format = st.sampled_from([2, 3]) @st.composite # type: ignore[misc] @@ -70,12 +73,14 @@ def arrays( paths: st.SearchStrategy[None | str] = paths, array_names: st.SearchStrategy = array_names, attrs: st.SearchStrategy = attrs, + format: st.SearchStrategy = format, ) -> Array: store = draw(stores) nparray, chunks = draw(np_array_and_chunks(arrays=arrays)) path = draw(paths) name = draw(array_names) attributes = draw(attrs) + zarr_format = draw(format) # compressor = draw(compressors) # TODO: clean this up @@ -100,7 +105,7 @@ def arrays( expected_attrs = {} if attributes is None else attributes array_path = path + ("/" if not path.endswith("/") else "") + name - root = Group.create(store) + root = Group.from_store(store, zarr_format=zarr_format) fill_value_args: tuple[Any, ...] 
= tuple() if nparray.dtype.kind == "M": m = re.search(r"\[(.+)\]", nparray.dtype.str) diff --git a/tests/v3/conftest.py b/tests/v3/conftest.py index d216094401..20b77d888e 100644 --- a/tests/v3/conftest.py +++ b/tests/v3/conftest.py @@ -98,7 +98,7 @@ async def async_group(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> As param: AsyncGroupRequest = request.param store = await parse_store(param.store, str(tmpdir)) - agroup = await AsyncGroup.create( + agroup = await AsyncGroup.from_store( store, attributes=param.attributes, zarr_format=param.zarr_format, diff --git a/tests/v3/test_api.py b/tests/v3/test_api.py index ddfab587cc..1b4330eef3 100644 --- a/tests/v3/test_api.py +++ b/tests/v3/test_api.py @@ -1,4 +1,5 @@ import pathlib +import warnings import numpy as np import pytest @@ -7,7 +8,7 @@ import zarr from zarr import Array, Group from zarr.abc.store import Store -from zarr.api.synchronous import create, load, open, open_group, save, save_array, save_group +from zarr.api.synchronous import create, group, load, open, open_group, save, save_array, save_group from zarr.core.common import ZarrFormat from zarr.store.memory import MemoryStore @@ -108,7 +109,7 @@ def test_save_errors() -> None: save_group("data/group.zarr") with pytest.raises(TypeError): # no array provided - save_array("data/group.zarr") # type: ignore[call-arg] + save_array("data/group.zarr") with pytest.raises(ValueError): # no arrays provided save("data/group.zarr") @@ -118,9 +119,11 @@ def test_open_with_mode_r(tmp_path: pathlib.Path) -> None: # 'r' means read only (must exist) with pytest.raises(FileNotFoundError): zarr.open(store=tmp_path, mode="r") - zarr.ones(store=tmp_path, shape=(3, 3)) + z1 = zarr.ones(store=tmp_path, shape=(3, 3)) + assert z1.fill_value == 1 z2 = zarr.open(store=tmp_path, mode="r") assert isinstance(z2, Array) + assert z2.fill_value == 1 assert (z2[:] == 1).all() with pytest.raises(ValueError): z2[:] = 3 @@ -878,3 +881,37 @@ def test_tree() -> None: # # bad option 
# with pytest.raises(TypeError): # copy(source["foo"], dest, dry_run=True, log=True) + + +def test_open_positional_args_deprecated() -> None: + store = MemoryStore({}, mode="w") + with pytest.warns(FutureWarning, match="pass"): + open(store, "w", shape=(1,)) + + +def test_save_array_positional_args_deprecated() -> None: + store = MemoryStore({}, mode="w") + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="zarr_version is deprecated", category=DeprecationWarning + ) + with pytest.warns(FutureWarning, match="pass"): + save_array( + store, + np.ones( + 1, + ), + 3, + ) + + +def test_group_positional_args_deprecated() -> None: + store = MemoryStore({}, mode="w") + with pytest.warns(FutureWarning, match="pass"): + group(store, True) + + +def test_open_group_positional_args_deprecated() -> None: + store = MemoryStore({}, mode="w") + with pytest.warns(FutureWarning, match="pass"): + open_group(store, "w") diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index b7beb63b1c..b3362c52b0 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -5,6 +5,7 @@ import pytest from zarr import Array, AsyncArray, Group +from zarr.core.buffer.cpu import NDBuffer from zarr.core.common import ZarrFormat from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.store import LocalStore, MemoryStore @@ -25,7 +26,7 @@ def test_array_creation_existing_node( Check that an existing array or group is handled as expected during array creation. 
""" spath = StorePath(store) - group = Group.create(spath, zarr_format=zarr_format) + group = Group.from_store(spath, zarr_format=zarr_format) expected_exception: type[ContainsArrayError] | type[ContainsGroupError] if extant_node == "array": expected_exception = ContainsArrayError @@ -76,7 +77,7 @@ def test_array_name_properties_no_group( def test_array_name_properties_with_group( store: LocalStore | MemoryStore, zarr_format: ZarrFormat ) -> None: - root = Group.create(store=store, zarr_format=zarr_format) + root = Group.from_store(store=store, zarr_format=zarr_format) foo = root.create_array("foo", shape=(100,), chunks=(10,), dtype="i4") assert foo.path == "foo" assert foo.name == "/foo" @@ -121,8 +122,10 @@ def test_array_v3_fill_value_default( @pytest.mark.parametrize("store", ["memory"], indirect=True) -@pytest.mark.parametrize("fill_value", [False, 0.0, 1, 2.3]) -@pytest.mark.parametrize("dtype_str", ["bool", "uint8", "float32", "complex64"]) +@pytest.mark.parametrize( + "dtype_str,fill_value", + [("bool", True), ("uint8", 99), ("float32", -99.9), ("complex64", 3 + 4j)], +) def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str) -> None: shape = (10,) arr = Array.create( @@ -138,6 +141,47 @@ def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str assert arr.fill_value.dtype == arr.dtype +def test_create_positional_args_deprecated() -> None: + store = MemoryStore({}, mode="w") + with pytest.warns(FutureWarning, match="Pass"): + Array.create(store, (2, 2), dtype="f8") + + +def test_selection_positional_args_deprecated() -> None: + store = MemoryStore({}, mode="w") + arr = Array.create(store, shape=(2, 2), dtype="f8") + + with pytest.warns(FutureWarning, match="Pass out"): + arr.get_basic_selection(..., NDBuffer(array=np.empty((2, 2)))) + + with pytest.warns(FutureWarning, match="Pass fields"): + arr.set_basic_selection(..., 1, None) + + with pytest.warns(FutureWarning, match="Pass out"): + 
arr.get_orthogonal_selection(..., NDBuffer(array=np.empty((2, 2)))) + + with pytest.warns(FutureWarning, match="Pass"): + arr.set_orthogonal_selection(..., 1, None) + + with pytest.warns(FutureWarning, match="Pass"): + arr.get_mask_selection(np.zeros((2, 2), dtype=bool), NDBuffer(array=np.empty((0,)))) + + with pytest.warns(FutureWarning, match="Pass"): + arr.set_mask_selection(np.zeros((2, 2), dtype=bool), 1, None) + + with pytest.warns(FutureWarning, match="Pass"): + arr.get_coordinate_selection(([0, 1], [0, 1]), NDBuffer(array=np.empty((2,)))) + + with pytest.warns(FutureWarning, match="Pass"): + arr.set_coordinate_selection(([0, 1], [0, 1]), 1, None) + + with pytest.warns(FutureWarning, match="Pass"): + arr.get_block_selection((0, slice(None)), NDBuffer(array=np.empty((2, 2)))) + + with pytest.warns(FutureWarning, match="Pass"): + arr.set_block_selection((0, slice(None)), 1, None) + + @pytest.mark.parametrize("store", ["memory"], indirect=True) async def test_array_v3_nan_fill_value(store: MemoryStore) -> None: shape = (10,) diff --git a/tests/v3/test_attributes.py b/tests/v3/test_attributes.py new file mode 100644 index 0000000000..65b6a02e8d --- /dev/null +++ b/tests/v3/test_attributes.py @@ -0,0 +1,13 @@ +import zarr.core +import zarr.core.attributes +import zarr.store + + +def test_put() -> None: + store = zarr.store.MemoryStore({}, mode="w") + attrs = zarr.core.attributes.Attributes( + zarr.Group.from_store(store, attributes={"a": 1, "b": 2}) + ) + attrs.put({"a": 3, "c": 4}) + expected = {"a": 3, "c": 4} + assert dict(attrs) == expected diff --git a/tests/v3/test_buffer.py b/tests/v3/test_buffer.py index 5a313dc1ab..cde3f85780 100644 --- a/tests/v3/test_buffer.py +++ b/tests/v3/test_buffer.py @@ -32,6 +32,10 @@ cp = None +if TYPE_CHECKING: + import types + + def test_nd_array_like(xp: types.ModuleType) -> None: ary = xp.arange(10) assert isinstance(ary, ArrayLike) diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index d5fb9e7b5a..c8310f33e5 
100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -51,7 +51,7 @@ def test_group_init(store: Store, zarr_format: ZarrFormat) -> None: """ Test that initializing a group from an asyncgroup works. """ - agroup = sync(AsyncGroup.create(store=store, zarr_format=zarr_format)) + agroup = sync(AsyncGroup.from_store(store=store, zarr_format=zarr_format)) group = Group(agroup) assert group._async_group == agroup @@ -60,7 +60,7 @@ def test_group_name_properties(store: Store, zarr_format: ZarrFormat) -> None: """ Test basic properties of groups """ - root = Group.create(store=store, zarr_format=zarr_format) + root = Group.from_store(store=store, zarr_format=zarr_format) assert root.path == "" assert root.name == "/" assert root.basename == "" @@ -178,16 +178,18 @@ def test_group(store: Store, zarr_format: ZarrFormat) -> None: def test_group_create(store: Store, exists_ok: bool, zarr_format: ZarrFormat) -> None: """ - Test that `Group.create` works as expected. + Test that `Group.from_store` works as expected. 
""" attributes = {"foo": 100} - group = Group.create(store, attributes=attributes, zarr_format=zarr_format, exists_ok=exists_ok) + group = Group.from_store( + store, attributes=attributes, zarr_format=zarr_format, exists_ok=exists_ok + ) assert group.attrs == attributes if not exists_ok: with pytest.raises(ContainsGroupError): - group = Group.create( + group = Group.from_store( store, attributes=attributes, exists_ok=exists_ok, zarr_format=zarr_format ) @@ -203,7 +205,7 @@ def test_group_open(store: Store, zarr_format: ZarrFormat, exists_ok: bool) -> N # create the group attrs = {"path": "foo"} - group_created = Group.create( + group_created = Group.from_store( store, attributes=attrs, zarr_format=zarr_format, exists_ok=exists_ok ) assert group_created.attrs == attrs @@ -214,9 +216,9 @@ def test_group_open(store: Store, zarr_format: ZarrFormat, exists_ok: bool) -> N new_attrs = {"path": "bar"} if not exists_ok: with pytest.raises(ContainsGroupError): - Group.create(store, attributes=attrs, zarr_format=zarr_format, exists_ok=exists_ok) + Group.from_store(store, attributes=attrs, zarr_format=zarr_format, exists_ok=exists_ok) else: - group_created_again = Group.create( + group_created_again = Group.from_store( store, attributes=new_attrs, zarr_format=zarr_format, exists_ok=exists_ok ) assert group_created_again.attrs == new_attrs @@ -229,7 +231,7 @@ def test_group_getitem(store: Store, zarr_format: ZarrFormat) -> None: Test the `Group.__getitem__` method. 
""" - group = Group.create(store, zarr_format=zarr_format) + group = Group.from_store(store, zarr_format=zarr_format) subgroup = group.create_group(name="subgroup") subarray = group.create_array(name="subarray", shape=(10,), chunk_shape=(10,)) @@ -246,7 +248,7 @@ def test_group_delitem(store: Store, zarr_format: ZarrFormat) -> None: if not store.supports_deletes: pytest.skip("store does not support deletes") - group = Group.create(store, zarr_format=zarr_format) + group = Group.from_store(store, zarr_format=zarr_format) subgroup = group.create_group(name="subgroup") subarray = group.create_array(name="subarray", shape=(10,), chunk_shape=(10,)) @@ -267,7 +269,7 @@ def test_group_iter(store: Store, zarr_format: ZarrFormat) -> None: Test the `Group.__iter__` method. """ - group = Group.create(store, zarr_format=zarr_format) + group = Group.from_store(store, zarr_format=zarr_format) with pytest.raises(NotImplementedError): [x for x in group] # type: ignore @@ -277,7 +279,7 @@ def test_group_len(store: Store, zarr_format: ZarrFormat) -> None: Test the `Group.__len__` method. """ - group = Group.create(store, zarr_format=zarr_format) + group = Group.from_store(store, zarr_format=zarr_format) with pytest.raises(NotImplementedError): len(group) # type: ignore @@ -286,7 +288,7 @@ def test_group_setitem(store: Store, zarr_format: ZarrFormat) -> None: """ Test the `Group.__setitem__` method. 
""" - group = Group.create(store, zarr_format=zarr_format) + group = Group.from_store(store, zarr_format=zarr_format) with pytest.raises(NotImplementedError): group["key"] = 10 @@ -295,40 +297,34 @@ def test_group_contains(store: Store, zarr_format: ZarrFormat) -> None: """ Test the `Group.__contains__` method """ - group = Group.create(store, zarr_format=zarr_format) + group = Group.from_store(store, zarr_format=zarr_format) assert "foo" not in group _ = group.create_group(name="foo") assert "foo" in group -def test_group_subgroups(store: Store, zarr_format: ZarrFormat) -> None: - """ - Test the behavior of `Group` methods for accessing subgroups, namely `Group.group_keys` and `Group.groups` - """ - group = Group.create(store, zarr_format=zarr_format) - keys = ("foo", "bar") - subgroups_expected = tuple(group.create_group(k) for k in keys) - # create a sub-array as well - _ = group.create_array("array", shape=(10,)) - subgroups_observed = group.groups() - assert set(group.group_keys()) == set(keys) - assert len(subgroups_observed) == len(subgroups_expected) - assert all(a in subgroups_observed for a in subgroups_expected) +def test_group_child_iterators(store: Store, zarr_format: ZarrFormat): + group = Group.from_store(store, zarr_format=zarr_format) + expected_group_keys = ["g0", "g1"] + expected_group_values = [group.create_group(name=name) for name in expected_group_keys] + expected_groups = list(zip(expected_group_keys, expected_group_values, strict=False)) + expected_group_values[0].create_group("subgroup") + expected_group_values[0].create_array("subarray", shape=(1,)) -def test_group_subarrays(store: Store, zarr_format: ZarrFormat) -> None: - """ - Test the behavior of `Group` methods for accessing subgroups, namely `Group.group_keys` and `Group.groups` - """ - group = Group.create(store, zarr_format=zarr_format) - keys = ("foo", "bar") - subarrays_expected = tuple(group.create_array(k, shape=(10,)) for k in keys) - # create a sub-group as well - _ = 
group.create_group("group") - subarrays_observed = group.arrays() - assert set(group.array_keys()) == set(keys) - assert len(subarrays_observed) == len(subarrays_expected) - assert all(a in subarrays_observed for a in subarrays_expected) + expected_array_keys = ["a0", "a1"] + expected_array_values = [ + group.create_array(name=name, shape=(1,)) for name in expected_array_keys + ] + expected_arrays = list(zip(expected_array_keys, expected_array_values, strict=False)) + + assert sorted(group.groups(), key=lambda x: x[0]) == expected_groups + assert sorted(group.group_keys()) == expected_group_keys + assert sorted(group.group_values(), key=lambda x: x.name) == expected_group_values + + assert sorted(group.arrays(), key=lambda x: x[0]) == expected_arrays + assert sorted(group.array_keys()) == expected_array_keys + assert sorted(group.array_values(), key=lambda x: x.name) == expected_array_values def test_group_update_attributes(store: Store, zarr_format: ZarrFormat) -> None: @@ -336,7 +332,7 @@ def test_group_update_attributes(store: Store, zarr_format: ZarrFormat) -> None: Test the behavior of `Group.update_attributes` """ attrs = {"foo": 100} - group = Group.create(store, zarr_format=zarr_format, attributes=attrs) + group = Group.from_store(store, zarr_format=zarr_format, attributes=attrs) assert group.attrs == attrs new_attrs = {"bar": 100} new_group = group.update_attributes(new_attrs) @@ -348,7 +344,7 @@ async def test_group_update_attributes_async(store: Store, zarr_format: ZarrForm Test the behavior of `Group.update_attributes_async` """ attrs = {"foo": 100} - group = Group.create(store, zarr_format=zarr_format, attributes=attrs) + group = Group.from_store(store, zarr_format=zarr_format, attributes=attrs) assert group.attrs == attrs new_attrs = {"bar": 100} new_group = await group.update_attributes_async(new_attrs) @@ -363,9 +359,9 @@ def test_group_create_array( method: Literal["create_array", "array"], ) -> None: """ - Test `Group.create_array` + Test 
 `Group.create_array`
     """
-    group = Group.create(store, zarr_format=zarr_format)
+    group = Group.from_store(store, zarr_format=zarr_format)
     shape = (10, 10)
     dtype = "uint8"
     data = np.arange(np.prod(shape)).reshape(shape).astype(dtype)
@@ -390,6 +386,73 @@ def test_group_create_array(
     assert np.array_equal(array[:], data)
 
 
+def test_group_array_creation(
+    store: Store,
+    zarr_format: ZarrFormat,
+) -> None:
+    group = Group.from_store(store, zarr_format=zarr_format)
+    shape = (10, 10)
+    empty_array = group.empty(name="empty", shape=shape)
+    assert isinstance(empty_array, Array)
+    assert empty_array.fill_value == 0
+    assert empty_array.shape == shape
+    assert empty_array.store_path.store == store
+
+    empty_like_array = group.empty_like(name="empty_like", prototype=empty_array)
+    assert isinstance(empty_like_array, Array)
+    assert empty_like_array.fill_value == 0
+    assert empty_like_array.shape == shape
+    assert empty_like_array.store_path.store == store
+
+    empty_array_bool = group.empty(name="empty_bool", shape=shape, dtype=np.dtype("bool"))
+    assert isinstance(empty_array_bool, Array)
+    assert not empty_array_bool.fill_value
+    assert empty_array_bool.shape == shape
+    assert empty_array_bool.store_path.store == store
+
+    empty_like_array_bool = group.empty_like(name="empty_like_bool", prototype=empty_array_bool)
+    assert isinstance(empty_like_array_bool, Array)
+    assert not empty_like_array_bool.fill_value
+    assert empty_like_array_bool.shape == shape
+    assert empty_like_array_bool.store_path.store == store
+
+    zeros_array = group.zeros(name="zeros", shape=shape)
+    assert isinstance(zeros_array, Array)
+    assert zeros_array.fill_value == 0
+    assert zeros_array.shape == shape
+    assert zeros_array.store_path.store == store
+
+    zeros_like_array = group.zeros_like(name="zeros_like", prototype=zeros_array)
+    assert isinstance(zeros_like_array, Array)
+    assert zeros_like_array.fill_value == 0
+    assert zeros_like_array.shape == shape
+    assert zeros_like_array.store_path.store == store 
+ + ones_array = group.ones(name="ones", shape=shape) + assert isinstance(ones_array, Array) + assert ones_array.fill_value == 1 + assert ones_array.shape == shape + assert ones_array.store_path.store == store + + ones_like_array = group.ones_like(name="ones_like", prototype=ones_array) + assert isinstance(ones_like_array, Array) + assert ones_like_array.fill_value == 1 + assert ones_like_array.shape == shape + assert ones_like_array.store_path.store == store + + full_array = group.full(name="full", shape=shape, fill_value=42) + assert isinstance(full_array, Array) + assert full_array.fill_value == 42 + assert full_array.shape == shape + assert full_array.store_path.store == store + + full_like_array = group.full_like(name="full_like", prototype=full_array, fill_value=43) + assert isinstance(full_like_array, Array) + assert full_like_array.fill_value == 43 + assert full_like_array.shape == shape + assert full_like_array.store_path.store == store + + @pytest.mark.parametrize("store", ("local", "memory", "zip"), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) @pytest.mark.parametrize("exists_ok", [True, False]) @@ -404,7 +467,7 @@ def test_group_creation_existing_node( Check that an existing array or group is handled as expected during group creation. 
""" spath = StorePath(store) - group = Group.create(spath, zarr_format=zarr_format) + group = Group.from_store(spath, zarr_format=zarr_format) expected_exception: type[ContainsArrayError] | type[ContainsGroupError] attributes: dict[str, JSON] = {"old": True} @@ -420,7 +483,7 @@ def test_group_creation_existing_node( new_attributes = {"new": True} if exists_ok: - node_new = Group.create( + node_new = Group.from_store( spath / "extant", attributes=new_attributes, zarr_format=zarr_format, @@ -429,7 +492,7 @@ def test_group_creation_existing_node( assert node_new.attrs == new_attributes else: with pytest.raises(expected_exception): - node_new = Group.create( + node_new = Group.from_store( spath / "extant", attributes=new_attributes, zarr_format=zarr_format, @@ -443,11 +506,11 @@ async def test_asyncgroup_create( zarr_format: ZarrFormat, ) -> None: """ - Test that `AsyncGroup.create` works as expected. + Test that `AsyncGroup.from_store` works as expected. """ spath = StorePath(store=store) attributes = {"foo": 100} - agroup = await AsyncGroup.create( + agroup = await AsyncGroup.from_store( store, attributes=attributes, exists_ok=exists_ok, @@ -459,7 +522,7 @@ async def test_asyncgroup_create( if not exists_ok: with pytest.raises(ContainsGroupError): - agroup = await AsyncGroup.create( + agroup = await AsyncGroup.from_store( spath, attributes=attributes, exists_ok=exists_ok, @@ -471,7 +534,7 @@ async def test_asyncgroup_create( spath / collision_name, shape=(10,), dtype="uint8", zarr_format=zarr_format ) with pytest.raises(ContainsArrayError): - _ = await AsyncGroup.create( + _ = await AsyncGroup.from_store( StorePath(store=store) / collision_name, attributes=attributes, exists_ok=exists_ok, @@ -481,13 +544,13 @@ async def test_asyncgroup_create( async def test_asyncgroup_attrs(store: Store, zarr_format: ZarrFormat) -> None: attributes = {"foo": 100} - agroup = await AsyncGroup.create(store, zarr_format=zarr_format, attributes=attributes) + agroup = await 
AsyncGroup.from_store(store, zarr_format=zarr_format, attributes=attributes) assert agroup.attrs == agroup.metadata.attributes == attributes async def test_asyncgroup_info(store: Store, zarr_format: ZarrFormat) -> None: - agroup = await AsyncGroup.create( # noqa + agroup = await AsyncGroup.from_store( # noqa store, zarr_format=zarr_format, ) @@ -503,7 +566,7 @@ async def test_asyncgroup_open( Create an `AsyncGroup`, then ensure that we can open it using `AsyncGroup.open` """ attributes = {"foo": 100} - group_w = await AsyncGroup.create( + group_w = await AsyncGroup.from_store( store=store, attributes=attributes, exists_ok=False, @@ -520,7 +583,7 @@ async def test_asyncgroup_open_wrong_format( store: Store, zarr_format: ZarrFormat, ) -> None: - _ = await AsyncGroup.create(store=store, exists_ok=False, zarr_format=zarr_format) + _ = await AsyncGroup.from_store(store=store, exists_ok=False, zarr_format=zarr_format) zarr_format_wrong: ZarrFormat # try opening with the wrong zarr format if zarr_format == 3: @@ -563,7 +626,7 @@ async def test_asyncgroup_getitem(store: Store, zarr_format: ZarrFormat) -> None Create an `AsyncGroup`, then create members of that group, and ensure that we can access those members via the `AsyncGroup.getitem` method. 
""" - agroup = await AsyncGroup.create(store=store, zarr_format=zarr_format) + agroup = await AsyncGroup.from_store(store=store, zarr_format=zarr_format) array_name = "sub_array" sub_array = await agroup.create_array( @@ -584,7 +647,7 @@ async def test_asyncgroup_delitem(store: Store, zarr_format: ZarrFormat) -> None if not store.supports_deletes: pytest.skip("store does not support deletes") - agroup = await AsyncGroup.create(store=store, zarr_format=zarr_format) + agroup = await AsyncGroup.from_store(store=store, zarr_format=zarr_format) array_name = "sub_array" _ = await agroup.create_array( name=array_name, shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100} @@ -616,7 +679,7 @@ async def test_asyncgroup_create_group( store: Store, zarr_format: ZarrFormat, ) -> None: - agroup = await AsyncGroup.create(store=store, zarr_format=zarr_format) + agroup = await AsyncGroup.from_store(store=store, zarr_format=zarr_format) sub_node_path = "sub_group" attributes = {"foo": 999} subnode = await agroup.create_group(name=sub_node_path, attributes=attributes) @@ -636,11 +699,11 @@ async def test_asyncgroup_create_array( specified in create_array are present on the resulting array. 
""" - agroup = await AsyncGroup.create(store=store, zarr_format=zarr_format) + agroup = await AsyncGroup.from_store(store=store, zarr_format=zarr_format) if not exists_ok: with pytest.raises(ContainsGroupError): - agroup = await AsyncGroup.create(store=store, zarr_format=zarr_format) + agroup = await AsyncGroup.from_store(store=store, zarr_format=zarr_format) shape = (10,) dtype = "uint8" @@ -673,7 +736,7 @@ async def test_asyncgroup_update_attributes(store: Store, zarr_format: ZarrForma """ attributes_old = {"foo": 10} attributes_new = {"baz": "new"} - agroup = await AsyncGroup.create( + agroup = await AsyncGroup.from_store( store=store, zarr_format=zarr_format, attributes=attributes_old ) @@ -684,7 +747,7 @@ async def test_asyncgroup_update_attributes(store: Store, zarr_format: ZarrForma @pytest.mark.parametrize("store", ("local",), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) async def test_serializable_async_group(store: LocalStore, zarr_format: ZarrFormat) -> None: - expected = await AsyncGroup.create( + expected = await AsyncGroup.from_store( store=store, attributes={"foo": 999}, zarr_format=zarr_format ) p = pickle.dumps(expected) @@ -695,7 +758,7 @@ async def test_serializable_async_group(store: LocalStore, zarr_format: ZarrForm @pytest.mark.parametrize("store", ("local",), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) def test_serializable_sync_group(store: LocalStore, zarr_format: ZarrFormat) -> None: - expected = Group.create(store=store, attributes={"foo": 999}, zarr_format=zarr_format) + expected = Group.from_store(store=store, attributes={"foo": 999}, zarr_format=zarr_format) p = pickle.dumps(expected) actual = pickle.loads(p) @@ -756,7 +819,7 @@ async def test_group_members_async(store: LocalStore | MemoryStore) -> None: async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) + root = await 
AsyncGroup.from_store(store=store, zarr_format=zarr_format) # create foo group _ = await root.create_group("foo", attributes={"foo": 100}) @@ -784,7 +847,7 @@ async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrF async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) + root = await AsyncGroup.from_store(store=store, zarr_format=zarr_format) # create foo group _ = await root.create_group("foo", attributes={"foo": 100}) # create bar group @@ -805,7 +868,7 @@ async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: Zarr async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) + root = await AsyncGroup.from_store(store=store, zarr_format=zarr_format) with pytest.warns(DeprecationWarning): foo = await root.create_dataset("foo", shape=(10,), dtype="uint8") assert foo.shape == (10,) @@ -819,7 +882,7 @@ async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: Zarr async def test_require_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) + root = await AsyncGroup.from_store(store=store, zarr_format=zarr_format) foo1 = await root.require_array("foo", shape=(10,), dtype="i8", attributes={"foo": 101}) assert foo1.attrs == {"foo": 101} foo2 = await root.require_array("foo", shape=(10,), dtype="i8") diff --git a/tests/v3/test_indexing.py b/tests/v3/test_indexing.py index efb11f36a1..8b509f93d1 100644 --- a/tests/v3/test_indexing.py +++ b/tests/v3/test_indexing.py @@ -12,6 +12,10 @@ import zarr from zarr.core.buffer import BufferPrototype, default_buffer_prototype from zarr.core.indexing import ( + BasicSelection, + CoordinateSelection, + OrthogonalSelection, + Selection, make_slice_selection, 
normalize_integer_selection, oindex, @@ -23,13 +27,15 @@ from zarr.store.memory import MemoryStore if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import AsyncGenerator + from zarr.core.array import Array + from zarr.core.buffer.core import Buffer from zarr.core.common import ChunkCoords @pytest.fixture -async def store() -> Iterator[StorePath]: +async def store() -> AsyncGenerator[StorePath]: yield StorePath(await MemoryStore.open(mode="w")) @@ -50,18 +56,25 @@ def zarr_array_from_numpy_array( class CountingDict(MemoryStore): + counter: Counter[tuple[str, str]] + @classmethod async def open(cls) -> CountingDict: store = await super().open(mode="w") store.counter = Counter() return store - async def get(self, key, prototype: BufferPrototype, byte_range=None): + async def get( + self, + key: str, + prototype: BufferPrototype, + byte_range: tuple[int | None, int | None] | None = None, + ) -> Buffer | None: key_suffix = "/".join(key.split("/")[1:]) self.counter["__getitem__", key_suffix] += 1 return await super().get(key, prototype, byte_range) - async def set(self, key, value, byte_range=None): + async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: key_suffix = "/".join(key.split("/")[1:]) self.counter["__setitem__", key_suffix] += 1 return await super().set(key, value, byte_range) @@ -167,7 +180,7 @@ def test_get_basic_selection_0d(store: StorePath, use_out: bool, value: Any, dty # assert_array_equal(a[["foo", "bar"]], c) -basic_selections_1d = [ +basic_selections_1d: list[BasicSelection] = [ # single value 42, -1, @@ -241,7 +254,9 @@ def test_get_basic_selection_0d(store: StorePath, use_out: bool, value: Any, dty ] -def _test_get_basic_selection(a, z, selection) -> None: +def _test_get_basic_selection( + a: npt.NDArray[Any] | Array, z: Array, selection: BasicSelection +) -> None: expect = a[selection] actual = z.get_basic_selection(selection) assert_array_equal(expect, actual) @@ -265,17 
+280,17 @@ def test_get_basic_selection_1d(store: StorePath) -> None: for selection in basic_selections_1d: _test_get_basic_selection(a, z, selection) - for selection in basic_selections_1d_bad: + for selection_bad in basic_selections_1d_bad: with pytest.raises(IndexError): - z.get_basic_selection(selection) + z.get_basic_selection(selection_bad) # type: ignore[arg-type] with pytest.raises(IndexError): - z[selection] + z[selection_bad] # type: ignore[index] with pytest.raises(IndexError): z.get_basic_selection([1, 0]) # type: ignore[arg-type] -basic_selections_2d = [ +basic_selections_2d: list[BasicSelection] = [ # single row 42, -1, @@ -340,9 +355,9 @@ def test_get_basic_selection_2d(store: StorePath) -> None: [0, 1], (slice(None), [0, 1]), ] - for selection in bad_selections: + for selection_bad in bad_selections: with pytest.raises(IndexError): - z.get_basic_selection(selection) + z.get_basic_selection(selection_bad) # type: ignore[arg-type] # check fallback on fancy indexing fancy_selection = ([0, 1], [0, 1]) np.testing.assert_array_equal(z[fancy_selection], [0, 11]) @@ -389,7 +404,7 @@ def test_fancy_indexing_fallback_on_get_setitem(store: StorePath) -> None: ], ) def test_orthogonal_indexing_fallback_on_getitem_2d( - store: StorePath, index, expected_result + store: StorePath, index: Selection, expected_result: npt.ArrayLike ) -> None: """ Tests the orthogonal indexing fallback on __getitem__ for a 2D matrix. @@ -407,6 +422,9 @@ def test_orthogonal_indexing_fallback_on_getitem_2d( np.testing.assert_array_equal(z[index], expected_result) +Index = list[int] | tuple[slice | int | list[int], ...] 
+ + @pytest.mark.parametrize( "index,expected_result", [ @@ -421,7 +439,7 @@ def test_orthogonal_indexing_fallback_on_getitem_2d( ], ) def test_orthogonal_indexing_fallback_on_getitem_3d( - store: StorePath, index, expected_result + store: StorePath, index: Selection, expected_result: npt.ArrayLike ) -> None: """ Tests the orthogonal indexing fallback on __getitem__ for a 3D matrix. @@ -461,7 +479,7 @@ def test_orthogonal_indexing_fallback_on_getitem_3d( ], ) def test_orthogonal_indexing_fallback_on_setitem_2d( - store: StorePath, index, expected_result + store: StorePath, index: Selection, expected_result: npt.ArrayLike ) -> None: """ Tests the orthogonal indexing fallback on __setitem__ for a 3D matrix. @@ -485,9 +503,9 @@ def test_fancy_indexing_doesnt_mix_with_implicit_slicing(store: StorePath) -> No with pytest.raises(IndexError): np.testing.assert_array_equal(z2[[1, 2, 3], [1, 2, 3]], 0) with pytest.raises(IndexError): - z2[..., [1, 2, 3]] = 2 + z2[..., [1, 2, 3]] = 2 # type: ignore[index] with pytest.raises(IndexError): - np.testing.assert_array_equal(z2[..., [1, 2, 3]], 0) + np.testing.assert_array_equal(z2[..., [1, 2, 3]], 0) # type: ignore[index] @pytest.mark.parametrize( @@ -532,7 +550,9 @@ def test_set_basic_selection_0d( # arr_z[..., "foo", "bar"] = v[["foo", "bar"]] -def _test_get_orthogonal_selection(a, z, selection) -> None: +def _test_get_orthogonal_selection( + a: npt.NDArray[Any], z: Array, selection: OrthogonalSelection +) -> None: expect = oindex(a, selection) actual = z.get_orthogonal_selection(selection) assert_array_equal(expect, actual) @@ -558,7 +578,8 @@ def test_get_orthogonal_selection_1d_bool(store: StorePath) -> None: with pytest.raises(IndexError): z.oindex[np.zeros(2000, dtype=bool)] # too long with pytest.raises(IndexError): - z.oindex[[[True, False], [False, True]]] # too many dimensions + # too many dimensions + z.oindex[[[True, False], [False, True]]] # type: ignore[index] # noinspection PyStatementEffect @@ -594,14 +615,16 @@ 
def test_get_orthogonal_selection_1d_int(store: StorePath) -> None: [-(a.shape[0] + 1)], # out of bounds [[2, 4], [6, 8]], # too many dimensions ] - for selection in bad_selections: + for bad_selection in bad_selections: with pytest.raises(IndexError): - z.get_orthogonal_selection(selection) + z.get_orthogonal_selection(bad_selection) # type: ignore[arg-type] with pytest.raises(IndexError): - z.oindex[selection] + z.oindex[bad_selection] # type: ignore[index] -def _test_get_orthogonal_selection_2d(a, z, ix0, ix1) -> None: +def _test_get_orthogonal_selection_2d( + a: npt.NDArray[Any], z: Array, ix0: npt.NDArray[np.bool], ix1: npt.NDArray[np.bool] +) -> None: selections = [ # index both axes with array (ix0, ix1), @@ -651,17 +674,23 @@ def test_get_orthogonal_selection_2d(store: StorePath) -> None: ix1 = ix1[::-1] _test_get_orthogonal_selection_2d(a, z, ix0, ix1) - for selection in basic_selections_2d: - _test_get_orthogonal_selection(a, z, selection) + for selection_2d in basic_selections_2d: + _test_get_orthogonal_selection(a, z, selection_2d) - for selection in basic_selections_2d_bad: + for selection_2d_bad in basic_selections_2d_bad: with pytest.raises(IndexError): - z.get_orthogonal_selection(selection) + z.get_orthogonal_selection(selection_2d_bad) # type: ignore[arg-type] with pytest.raises(IndexError): - z.oindex[selection] + z.oindex[selection_2d_bad] # type: ignore[index] -def _test_get_orthogonal_selection_3d(a, z, ix0, ix1, ix2) -> None: +def _test_get_orthogonal_selection_3d( + a: npt.NDArray, + z: Array, + ix0: npt.NDArray[np.bool], + ix1: npt.NDArray[np.bool], + ix2: npt.NDArray[np.bool], +) -> None: selections = [ # single value (84, 42, 4), @@ -738,7 +767,9 @@ def test_orthogonal_indexing_edge_cases(store: StorePath) -> None: assert_array_equal(expect, actual) -def _test_set_orthogonal_selection(v, a, z, selection) -> None: +def _test_set_orthogonal_selection( + v: npt.NDArray[np.int_], a: npt.NDArray[Any], z: Array, selection: OrthogonalSelection 
+) -> None: for value in 42, oindex(v, selection), oindex(v, selection).tolist(): if isinstance(value, list) and value == []: # skip these cases as cannot preserve all dimensions @@ -782,7 +813,13 @@ def test_set_orthogonal_selection_1d(store: StorePath) -> None: _test_set_orthogonal_selection(v, a, z, selection) -def _test_set_orthogonal_selection_2d(v, a, z, ix0, ix1) -> None: +def _test_set_orthogonal_selection_2d( + v: npt.NDArray[np.int_], + a: npt.NDArray[np.int_], + z: Array, + ix0: npt.NDArray[np.bool], + ix1: npt.NDArray[np.bool], +) -> None: selections = [ # index both axes with array (ix0, ix1), @@ -825,7 +862,14 @@ def test_set_orthogonal_selection_2d(store: StorePath) -> None: _test_set_orthogonal_selection(v, a, z, selection) -def _test_set_orthogonal_selection_3d(v, a, z, ix0, ix1, ix2) -> None: +def _test_set_orthogonal_selection_3d( + v: npt.NDArray[np.int_], + a: npt.NDArray[np.int_], + z: Array, + ix0: npt.NDArray[np.bool], + ix1: npt.NDArray[np.bool], + ix2: npt.NDArray[np.bool], +) -> None: selections = ( # single value (84, 42, 4), @@ -906,7 +950,9 @@ def test_orthogonal_indexing_fallback_on_get_setitem(store: StorePath) -> None: np.testing.assert_array_equal(z2[:], [0, 1, 1, 1, 0]) -def _test_get_coordinate_selection(a, z, selection) -> None: +def _test_get_coordinate_selection( + a: npt.NDArray, z: Array, selection: CoordinateSelection +) -> None: expect = a[selection] actual = z.get_coordinate_selection(selection) assert_array_equal(expect, actual) @@ -967,9 +1013,9 @@ def test_get_coordinate_selection_1d(store: StorePath) -> None: ] for selection in bad_selections: with pytest.raises(IndexError): - z.get_coordinate_selection(selection) + z.get_coordinate_selection(selection) # type: ignore[arg-type] with pytest.raises(IndexError): - z.vindex[selection] + z.vindex[selection] # type: ignore[index] def test_get_coordinate_selection_2d(store: StorePath) -> None: @@ -978,6 +1024,8 @@ def test_get_coordinate_selection_2d(store: StorePath) -> 
None: z = zarr_array_from_numpy_array(store, a, chunk_shape=(300, 3)) np.random.seed(42) + ix0: npt.ArrayLike + ix1: npt.ArrayLike # test with different degrees of sparseness for p in 2, 0.5, 0.1, 0.01: n = int(a.size * p) @@ -1014,19 +1062,21 @@ def test_get_coordinate_selection_2d(store: StorePath) -> None: with pytest.raises(IndexError): selection = slice(5, 15), [1, 2, 3] - z.get_coordinate_selection(selection) + z.get_coordinate_selection(selection) # type:ignore[arg-type] with pytest.raises(IndexError): selection = [1, 2, 3], slice(5, 15) - z.get_coordinate_selection(selection) + z.get_coordinate_selection(selection) # type:ignore[arg-type] with pytest.raises(IndexError): selection = Ellipsis, [1, 2, 3] - z.get_coordinate_selection(selection) + z.get_coordinate_selection(selection) # type:ignore[arg-type] with pytest.raises(IndexError): selection = Ellipsis - z.get_coordinate_selection(selection) + z.get_coordinate_selection(selection) # type:ignore[arg-type] -def _test_set_coordinate_selection(v, a, z, selection) -> None: +def _test_set_coordinate_selection( + v: npt.NDArray, a: npt.NDArray, z: Array, selection: CoordinateSelection +) -> None: for value in 42, v[selection], v[selection].tolist(): # setup expectation a[:] = 0 @@ -1060,9 +1110,9 @@ def test_set_coordinate_selection_1d(store: StorePath) -> None: for selection in coordinate_selections_1d_bad: with pytest.raises(IndexError): - z.set_coordinate_selection(selection, 42) + z.set_coordinate_selection(selection, 42) # type:ignore[arg-type] with pytest.raises(IndexError): - z.vindex[selection] = 42 + z.vindex[selection] = 42 # type:ignore[index] def test_set_coordinate_selection_2d(store: StorePath) -> None: @@ -1096,7 +1146,12 @@ def test_set_coordinate_selection_2d(store: StorePath) -> None: _test_set_coordinate_selection(v, a, z, (ix0, ix1)) -def _test_get_block_selection(a, z, selection, expected_idx) -> None: +def _test_get_block_selection( + a: npt.NDArray[Any], + z: Array, + selection: 
BasicSelection, + expected_idx: slice | tuple[slice, ...], +) -> None: expect = a[expected_idx] actual = z.get_block_selection(selection) assert_array_equal(expect, actual) @@ -1104,7 +1159,7 @@ def _test_get_block_selection(a, z, selection, expected_idx) -> None: assert_array_equal(expect, actual) -block_selections_1d = [ +block_selections_1d: list[BasicSelection] = [ # test single item 0, 5, @@ -1119,7 +1174,7 @@ def _test_get_block_selection(a, z, selection, expected_idx) -> None: slice(None), # Full slice ] -block_selections_1d_array_projection = [ +block_selections_1d_array_projection: list[slice] = [ # test single item slice(100), slice(500, 600), @@ -1163,14 +1218,14 @@ def test_get_block_selection_1d(store: StorePath) -> None: -(z.metadata.chunk_grid.get_nchunks(z.shape) + 1), # out of bounds ] - for selection in bad_selections: + for selection_bad in bad_selections: with pytest.raises(IndexError): - z.get_block_selection(selection) + z.get_block_selection(selection_bad) # type:ignore[arg-type] with pytest.raises(IndexError): - z.blocks[selection] + z.blocks[selection_bad] # type:ignore[index] -block_selections_2d = [ +block_selections_2d: list[BasicSelection] = [ # test single item (0, 0), (1, 2), @@ -1185,7 +1240,7 @@ def test_get_block_selection_1d(store: StorePath) -> None: (slice(None), slice(None)), # Full slice ] -block_selections_2d_array_projection = [ +block_selections_2d_array_projection: list[tuple[slice, slice]] = [ # test single item (slice(300), slice(3)), (slice(300, 600), slice(6, 9)), @@ -1223,7 +1278,11 @@ def test_get_block_selection_2d(store: StorePath) -> None: def _test_set_block_selection( - v: np.ndarray, a: np.ndarray, z: zarr.Array, selection, expected_idx + v: npt.NDArray[Any], + a: npt.NDArray[Any], + z: zarr.Array, + selection: BasicSelection, + expected_idx: slice, ) -> None: for value in 42, v[expected_idx], v[expected_idx].tolist(): # setup expectation @@ -1250,11 +1309,11 @@ def test_set_block_selection_1d(store: StorePath) 
-> None: ): _test_set_block_selection(v, a, z, selection, expected_idx) - for selection in block_selections_1d_bad: + for selection_bad in block_selections_1d_bad: with pytest.raises(IndexError): - z.set_block_selection(selection, 42) + z.set_block_selection(selection_bad, 42) # type:ignore[arg-type] with pytest.raises(IndexError): - z.blocks[selection] = 42 + z.blocks[selection_bad] = 42 # type:ignore[index] def test_set_block_selection_2d(store: StorePath) -> None: @@ -1279,7 +1338,7 @@ def test_set_block_selection_2d(store: StorePath) -> None: z.set_block_selection(selection, 42) -def _test_get_mask_selection(a, z, selection) -> None: +def _test_get_mask_selection(a: npt.NDArray[Any], z: Array, selection: npt.NDArray) -> None: expect = a[selection] actual = z.get_mask_selection(selection) assert_array_equal(expect, actual) @@ -1324,9 +1383,9 @@ def test_get_mask_selection_1d(store: StorePath) -> None: ] for selection in bad_selections: with pytest.raises(IndexError): - z.get_mask_selection(selection) + z.get_mask_selection(selection) # type: ignore[arg-type] with pytest.raises(IndexError): - z.vindex[selection] + z.vindex[selection] # type:ignore[index] # noinspection PyStatementEffect @@ -1350,7 +1409,9 @@ def test_get_mask_selection_2d(store: StorePath) -> None: z.vindex[[True, False]] # wrong no. 
dimensions -def _test_set_mask_selection(v, a, z, selection) -> None: +def _test_set_mask_selection( + v: npt.NDArray, a: npt.NDArray, z: Array, selection: npt.NDArray +) -> None: a[:] = 0 z[:] = 0 a[selection] = v[selection] @@ -1378,9 +1439,9 @@ def test_set_mask_selection_1d(store: StorePath) -> None: for selection in mask_selections_1d_bad: with pytest.raises(IndexError): - z.set_mask_selection(selection, 42) + z.set_mask_selection(selection, 42) # type: ignore[arg-type] with pytest.raises(IndexError): - z.vindex[selection] = 42 + z.vindex[selection] = 42 # type: ignore[index] def test_set_mask_selection_2d(store: StorePath) -> None: @@ -1413,7 +1474,7 @@ def test_get_selection_out(store: StorePath) -> None: assert_array_equal(expect, out.as_numpy_array()[:]) with pytest.raises(TypeError): - z.get_basic_selection(Ellipsis, out=[]) + z.get_basic_selection(Ellipsis, out=[]) # type: ignore[arg-type] # orthogonal selections a = np.arange(10000, dtype=int).reshape(1000, 10) @@ -1467,11 +1528,13 @@ def test_get_selection_out(store: StorePath) -> None: @pytest.mark.xfail(reason="fields are not supported in v3") def test_get_selections_with_fields(store: StorePath) -> None: - a = [("aaa", 1, 4.2), ("bbb", 2, 8.4), ("ccc", 3, 12.6)] - a = np.array(a, dtype=[("foo", "S3"), ("bar", "i4"), ("baz", "f8")]) + a = np.array( + [("aaa", 1, 4.2), ("bbb", 2, 8.4), ("ccc", 3, 12.6)], + dtype=[("foo", "S3"), ("bar", "i4"), ("baz", "f8")], + ) z = zarr_array_from_numpy_array(store, a, chunk_shape=(2,)) - fields_fixture = [ + fields_fixture: list[str | list[str]] = [ "foo", ["foo"], ["foo", "bar"], @@ -1568,17 +1631,19 @@ def test_get_selections_with_fields(store: StorePath) -> None: with pytest.raises(IndexError): z.get_basic_selection(Ellipsis, fields=["notafield"]) with pytest.raises(IndexError): - z.get_basic_selection(Ellipsis, fields=slice(None)) + z.get_basic_selection(Ellipsis, fields=slice(None)) # type: ignore[arg-type] @pytest.mark.xfail(reason="fields are not supported in 
v3") def test_set_selections_with_fields(store: StorePath) -> None: - v = [("aaa", 1, 4.2), ("bbb", 2, 8.4), ("ccc", 3, 12.6)] - v = np.array(v, dtype=[("foo", "S3"), ("bar", "i4"), ("baz", "f8")]) + v = np.array( + [("aaa", 1, 4.2), ("bbb", 2, 8.4), ("ccc", 3, 12.6)], + dtype=[("foo", "S3"), ("bar", "i4"), ("baz", "f8")], + ) a = np.empty_like(v) z = zarr_array_from_numpy_array(store, v, chunk_shape=(2,)) - fields_fixture = [ + fields_fixture: list[str | list[str]] = [ "foo", [], ["foo"], @@ -1597,11 +1662,11 @@ def test_set_selections_with_fields(store: StorePath) -> None: with pytest.raises(IndexError): z.set_basic_selection(Ellipsis, v, fields=fields) with pytest.raises(IndexError): - z.set_orthogonal_selection([0, 2], v, fields=fields) + z.set_orthogonal_selection([0, 2], v, fields=fields) # type: ignore[arg-type] with pytest.raises(IndexError): z.set_coordinate_selection([0, 2], v, fields=fields) with pytest.raises(IndexError): - z.set_mask_selection([True, False, True], v, fields=fields) + z.set_mask_selection([True, False, True], v, fields=fields) # type: ignore[arg-type] else: if isinstance(fields, list) and len(fields) == 1: @@ -1694,7 +1759,9 @@ def test_numpy_int_indexing(store: StorePath) -> None: ), ], ) -async def test_accessed_chunks(shape, chunks, ops) -> None: +async def test_accessed_chunks( + shape: tuple[int, ...], chunks: tuple[int, ...], ops: list[tuple[str, tuple[slice, ...]]] +) -> None: # Test that only the required chunks are accessed during basic selection operations # shape: array shape # chunks: chunk size @@ -1767,7 +1834,7 @@ async def test_accessed_chunks(shape, chunks, ops) -> None: [[100, 200, 300], [4, 5, 6]], ], ) -def test_indexing_equals_numpy(store, selection) -> None: +def test_indexing_equals_numpy(store: StorePath, selection: Selection) -> None: a = np.arange(10000, dtype=int).reshape(1000, 10) z = zarr_array_from_numpy_array(store, a, chunk_shape=(300, 3)) # note: in python 3.10 a[*selection] is not valid unpacking syntax 
@@ -1785,7 +1852,9 @@ def test_indexing_equals_numpy(store, selection) -> None: [np.full(1000, True), [True, False] * 5], ], ) -def test_orthogonal_bool_indexing_like_numpy_ix(store, selection) -> None: +def test_orthogonal_bool_indexing_like_numpy_ix( + store: StorePath, selection: list[npt.ArrayLike] +) -> None: a = np.arange(10000, dtype=int).reshape(1000, 10) z = zarr_array_from_numpy_array(store, a, chunk_shape=(300, 3)) expected = a[np.ix_(*selection)] diff --git a/tests/v3/test_metadata/test_v3.py b/tests/v3/test_metadata/test_v3.py index 0a545dfb9d..d4cf0c73e3 100644 --- a/tests/v3/test_metadata/test_v3.py +++ b/tests/v3/test_metadata/test_v3.py @@ -19,7 +19,12 @@ import numpy as np import pytest -from zarr.core.metadata.v3 import parse_dimension_names, parse_fill_value, parse_zarr_format +from zarr.core.metadata.v3 import ( + parse_dimension_names, + parse_dtype, + parse_fill_value, + parse_zarr_format, +) bool_dtypes = ("bool",) @@ -76,14 +81,34 @@ def test_parse_auto_fill_value(dtype_str: str) -> None: assert parse_fill_value(fill_value, dtype) == dtype.type(0) -@pytest.mark.parametrize("fill_value", [0, 1.11, False, True]) -@pytest.mark.parametrize("dtype_str", dtypes) +@pytest.mark.parametrize( + "fill_value,dtype_str", + [ + (True, "bool"), + (False, "bool"), + (-8, "int8"), + (0, "int16"), + (1e10, "uint64"), + (-999, "float32"), + (1e32, "float64"), + (float("NaN"), "float64"), + (np.nan, "float64"), + (np.inf, "float64"), + (-1 * np.inf, "float64"), + (0j, "complex64"), + ], +) def test_parse_fill_value_valid(fill_value: Any, dtype_str: str) -> None: """ Test that parse_fill_value(fill_value, dtype) casts fill_value to the given dtype. 
""" dtype = np.dtype(dtype_str) - assert parse_fill_value(fill_value, dtype) == dtype.type(fill_value) + parsed = parse_fill_value(fill_value, dtype) + + if np.isnan(fill_value): + assert np.isnan(parsed) + else: + assert parsed == dtype.type(fill_value) @pytest.mark.parametrize("fill_value", ["not a valid value"]) @@ -138,8 +163,7 @@ def test_parse_fill_value_invalid_type(fill_value: Any, dtype_str: str) -> None: This test excludes bool because the bool constructor takes anything. """ dtype = np.dtype(dtype_str) - match = "must be" - with pytest.raises(TypeError, match=match): + with pytest.raises(ValueError, match=r"fill value .* is not valid for dtype .*"): parse_fill_value(fill_value, dtype) @@ -234,22 +258,87 @@ def test_metadata_to_dict( assert observed == expected -@pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897]) -@pytest.mark.parametrize("precision", ["ns", "D"]) -async def test_datetime_metadata(fill_value: int, precision: str) -> None: +# @pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897]) +# @pytest.mark.parametrize("precision", ["ns", "D"]) +# async def test_datetime_metadata(fill_value: int, precision: str) -> None: +# metadata_dict = { +# "zarr_format": 3, +# "node_type": "array", +# "shape": (1,), +# "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, +# "data_type": f" None: metadata_dict = { "zarr_format": 3, "node_type": "array", "shape": (1,), "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, - "data_type": f" None: + metadata_dict = { + "zarr_format": 3, + "node_type": "array", + "shape": (1,), + "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, + "data_type": data_type, + "chunk_key_encoding": {"name": "default", "separator": "."}, + "codecs": (), + "fill_value": fill_value, # this is not a valid fill value for uint8 + } + with pytest.raises(ValueError, match=r"fill value .* is not valid for dtype .*"): + ArrayV3Metadata.from_dict(metadata_dict) + + 
+@pytest.mark.parametrize("fill_value", [("NaN"), "Infinity", "-Infinity"]) +async def test_special_float_fill_values(fill_value: str) -> None: + metadata_dict = { + "zarr_format": 3, + "node_type": "array", + "shape": (1,), + "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, + "data_type": "float64", + "chunk_key_encoding": {"name": "default", "separator": "."}, + "codecs": (), + "fill_value": fill_value, # this is not a valid fill value for uint8 + } + m = ArrayV3Metadata.from_dict(metadata_dict) + d = json.loads(m.to_buffer_dict(default_buffer_prototype())["zarr.json"].to_bytes()) + assert m.fill_value is not None + if fill_value == "NaN": + assert np.isnan(m.fill_value) + assert d["fill_value"] == "NaN" + elif fill_value == "Infinity": + assert np.isposinf(m.fill_value) + assert d["fill_value"] == "Infinity" + elif fill_value == "-Infinity": + assert np.isneginf(m.fill_value) + assert d["fill_value"] == "-Infinity" diff --git a/tests/v3/test_store/test_core.py b/tests/v3/test_store/test_core.py index c65d91f9d0..f401491127 100644 --- a/tests/v3/test_store/test_core.py +++ b/tests/v3/test_store/test_core.py @@ -1,10 +1,12 @@ +import tempfile from pathlib import Path import pytest -from zarr.store.common import make_store_path +from zarr.store.common import StoreLike, StorePath, make_store_path from zarr.store.local import LocalStore from zarr.store.memory import MemoryStore +from zarr.store.remote import RemoteStore async def test_make_store_path(tmpdir: str) -> None: @@ -34,3 +36,32 @@ async def test_make_store_path(tmpdir: str) -> None: with pytest.raises(TypeError): await make_store_path(1) # type: ignore[arg-type] + + +async def test_make_store_path_fsspec(monkeypatch) -> None: + import fsspec.implementations.memory + + monkeypatch.setattr(fsspec.implementations.memory.MemoryFileSystem, "async_impl", True) + store_path = await make_store_path("memory://") + assert isinstance(store_path.store, RemoteStore) + + 
+@pytest.mark.parametrize( + "store_like", + [ + None, + str(tempfile.TemporaryDirectory()), + Path(tempfile.TemporaryDirectory().name), + StorePath(store=MemoryStore(store_dict={}, mode="w"), path="/"), + MemoryStore(store_dict={}, mode="w"), + {}, + ], +) +async def test_make_store_path_storage_options_raises(store_like: StoreLike) -> None: + with pytest.raises(TypeError, match="storage_options"): + await make_store_path(store_like, storage_options={"foo": "bar"}, mode="w") + + +async def test_unsupported() -> None: + with pytest.raises(TypeError, match="Unsupported type for store_like: 'int'"): + await make_store_path(1) # type: ignore[arg-type] diff --git a/tests/v3/test_store/test_local.py b/tests/v3/test_store/test_local.py index 59cae22de3..5f1dde3fcc 100644 --- a/tests/v3/test_store/test_local.py +++ b/tests/v3/test_store/test_local.py @@ -35,6 +35,3 @@ def test_store_supports_partial_writes(self, store: LocalStore) -> None: def test_store_supports_listing(self, store: LocalStore) -> None: assert store.supports_listing - - def test_list_prefix(self, store: LocalStore) -> None: - assert True diff --git a/tests/v3/test_store/test_memory.py b/tests/v3/test_store/test_memory.py index 470383dfb6..2498cdc24a 100644 --- a/tests/v3/test_store/test_memory.py +++ b/tests/v3/test_store/test_memory.py @@ -1,7 +1,5 @@ from __future__ import annotations -import pickle - import pytest from zarr.core.buffer import Buffer, cpu, gpu @@ -48,16 +46,6 @@ def test_store_supports_partial_writes(self, store: MemoryStore) -> None: def test_list_prefix(self, store: MemoryStore) -> None: assert True - def test_serizalizable_store(self, store: MemoryStore) -> None: - with pytest.raises(NotImplementedError): - store.__getstate__() - - with pytest.raises(NotImplementedError): - store.__setstate__({}) - - with pytest.raises(NotImplementedError): - pickle.dumps(store) - @gpu_test class TestGpuMemoryStore(StoreTests[GpuMemoryStore, gpu.Buffer]): @@ -92,13 +80,3 @@ def 
test_store_supports_partial_writes(self, store: GpuMemoryStore) -> None: def test_list_prefix(self, store: GpuMemoryStore) -> None: assert True - - def test_serizalizable_store(self, store: MemoryStore) -> None: - with pytest.raises(NotImplementedError): - store.__getstate__() - - with pytest.raises(NotImplementedError): - store.__setstate__({}) - - with pytest.raises(NotImplementedError): - pickle.dumps(store) diff --git a/tests/v3/test_store/test_remote.py b/tests/v3/test_store/test_remote.py index afa991209f..495a5e5c4f 100644 --- a/tests/v3/test_store/test_remote.py +++ b/tests/v3/test_store/test_remote.py @@ -1,16 +1,27 @@ +from __future__ import annotations + +import json import os from collections.abc import Generator +from typing import TYPE_CHECKING -import botocore.client import fsspec import pytest +from botocore.session import Session from upath import UPath +import zarr.api.asynchronous from zarr.core.buffer import Buffer, cpu, default_buffer_prototype -from zarr.core.sync import sync +from zarr.core.sync import _collect_aiterator, sync from zarr.store import RemoteStore from zarr.testing.store import StoreTests +if TYPE_CHECKING: + from collections.abc import Generator + + import botocore.client + + s3fs = pytest.importorskip("s3fs") requests = pytest.importorskip("requests") moto_server = pytest.importorskip("moto.moto_server.threaded_moto_server") @@ -40,8 +51,6 @@ def s3_base() -> Generator[None, None, None]: def get_boto3_client() -> botocore.client.BaseClient: - from botocore.session import Session - # NB: we use the sync botocore client for setup session = Session() return session.create_client("s3", endpoint_url=endpoint_url) @@ -84,10 +93,12 @@ async def alist(it): async def test_basic() -> None: - store = await RemoteStore.open( - f"s3://{test_bucket_name}", mode="w", endpoint_url=endpoint_url, anon=False + store = RemoteStore.from_url( + f"s3://{test_bucket_name}", + mode="w", + storage_options=dict(endpoint_url=endpoint_url, anon=False), ) 
- assert not await alist(store.list()) + assert await _collect_aiterator(store.list()) == () assert not await store.exists("foo") data = b"hello" await store.set("foo", cpu.Buffer.from_bytes(data)) @@ -103,51 +114,33 @@ class TestRemoteStoreS3(StoreTests[RemoteStore, cpu.Buffer]): store_cls = RemoteStore buffer_cls = cpu.Buffer - @pytest.fixture(scope="function", params=("use_upath", "use_str")) + @pytest.fixture(scope="function") def store_kwargs(self, request) -> dict[str, str | bool]: - url = f"s3://{test_bucket_name}" - anon = False - mode = "r+" - if request.param == "use_upath": - return {"url": UPath(url, endpoint_url=endpoint_url, anon=anon), "mode": mode} - elif request.param == "use_str": - return {"url": url, "mode": mode, "anon": anon, "endpoint_url": endpoint_url} - - raise AssertionError + fs, path = fsspec.url_to_fs( + f"s3://{test_bucket_name}", endpoint_url=endpoint_url, anon=False + ) + return {"fs": fs, "path": path, "mode": "r+"} @pytest.fixture(scope="function") def store(self, store_kwargs: dict[str, str | bool]) -> RemoteStore: - url = store_kwargs["url"] - mode = store_kwargs["mode"] - if isinstance(url, UPath): - out = self.store_cls(url=url, mode=mode) - else: - endpoint_url = store_kwargs["endpoint_url"] - out = self.store_cls(url=url, asynchronous=True, mode=mode, endpoint_url=endpoint_url) - return out + return self.store_cls(**store_kwargs) def get(self, store: RemoteStore, key: str) -> Buffer: # make a new, synchronous instance of the filesystem because this test is run in sync code - fs, _ = fsspec.url_to_fs( - url=store._url, - asynchronous=False, - anon=store._fs.anon, - endpoint_url=store._fs.endpoint_url, + new_fs = fsspec.filesystem( + "s3", endpoint_url=store.fs.endpoint_url, anon=store.fs.anon, asynchronous=False ) - return self.buffer_cls.from_bytes(fs.cat(f"{store.path}/{key}")) + return self.buffer_cls.from_bytes(new_fs.cat(f"{store.path}/{key}")) def set(self, store: RemoteStore, key: str, value: Buffer) -> None: # make a 
new, synchronous instance of the filesystem because this test is run in sync code - fs, _ = fsspec.url_to_fs( - url=store._url, - asynchronous=False, - anon=store._fs.anon, - endpoint_url=store._fs.endpoint_url, + new_fs = fsspec.filesystem( + "s3", endpoint_url=store.fs.endpoint_url, anon=store.fs.anon, asynchronous=False ) - fs.write_bytes(f"{store.path}/{key}", value.to_bytes()) + new_fs.write_bytes(f"{store.path}/{key}", value.to_bytes()) def test_store_repr(self, store: RemoteStore) -> None: - assert str(store) == f"s3://{test_bucket_name}" + assert str(store) == "" def test_store_supports_writes(self, store: RemoteStore) -> None: assert True @@ -158,3 +151,47 @@ def test_store_supports_partial_writes(self, store: RemoteStore) -> None: def test_store_supports_listing(self, store: RemoteStore) -> None: assert True + + async def test_remote_store_from_uri( + self, store: RemoteStore, store_kwargs: dict[str, str | bool] + ): + storage_options = { + "endpoint_url": endpoint_url, + "anon": False, + } + + meta = {"attributes": {"key": "value"}, "zarr_format": 3, "node_type": "group"} + + await store.set( + "zarr.json", + self.buffer_cls.from_bytes(json.dumps(meta).encode()), + ) + group = await zarr.api.asynchronous.open_group( + store=f"s3://{test_bucket_name}", storage_options=storage_options + ) + assert dict(group.attrs) == {"key": "value"} + + meta["attributes"]["key"] = "value-2" + await store.set( + "directory-2/zarr.json", + self.buffer_cls.from_bytes(json.dumps(meta).encode()), + ) + group = await zarr.api.asynchronous.open_group( + store=f"s3://{test_bucket_name}/directory-2", storage_options=storage_options + ) + assert dict(group.attrs) == {"key": "value-2"} + + meta["attributes"]["key"] = "value-3" + await store.set( + "directory-3/zarr.json", + self.buffer_cls.from_bytes(json.dumps(meta).encode()), + ) + group = await zarr.api.asynchronous.open_group( + store=f"s3://{test_bucket_name}", path="directory-3", storage_options=storage_options + ) + assert 
dict(group.attrs) == {"key": "value-3"} + + def test_from_upath(self) -> None: + path = UPath(f"s3://{test_bucket_name}", endpoint_url=endpoint_url, anon=False) + result = RemoteStore.from_upath(path) + assert result.fs.endpoint_url == endpoint_url diff --git a/tests/v3/test_sync.py b/tests/v3/test_sync.py index 22834747e7..864c9e01cb 100644 --- a/tests/v3/test_sync.py +++ b/tests/v3/test_sync.py @@ -4,7 +4,9 @@ import pytest +import zarr from zarr.core.sync import SyncError, SyncMixin, _get_lock, _get_loop, sync +from zarr.store.memory import MemoryStore @pytest.fixture(params=[True, False]) @@ -121,3 +123,9 @@ def bar(self) -> list[int]: foo = SyncFoo(async_foo) assert foo.foo() == "foo" assert foo.bar() == list(range(10)) + + +def test_open_positional_args_deprecate(): + store = MemoryStore({}, mode="w") + with pytest.warns(FutureWarning, match="pass"): + zarr.open(store, "w", shape=(1,))