From 7e2be57e3d3c176730bb59aa23c77cd6005f4c85 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Tue, 8 Oct 2024 15:07:20 -0400 Subject: [PATCH] Add `string` and `bytes` dtypes plus `vlen-utf8` and `vlen-bytes` codecs (#2036) * add legacy vlen-utf8 codec * got it working again * got strings working; broke everything else * change v3.metadata.data_type type * fixed tests * satisfy mypy for tests * make strings work * add missing module * store -> storage * rename module * add vlen bytes * fix type assertions in test * much better validation of fill value * retype parse_fill_value * tests pass but not mypy * attempted to change parse_fill_value typing * restore DEFAULT_DTYPE * fixup * docstring * update test * add better DataType tests * more progress on typing; still not passing mypy * fix typing yay! * make types work with numpy <, 2 * Apply suggestions from code review Co-authored-by: Joe Hamman * Apply suggestions from code review Co-authored-by: Joe Hamman * apply Joe's suggestions * add missing module * make _STRING_DTYPE private to try to make sphinx happy --------- Co-authored-by: Davis Bennett Co-authored-by: Tom Augspurger Co-authored-by: Joe Hamman --- src/zarr/codecs/__init__.py | 21 +++ src/zarr/codecs/vlen_utf8.py | 117 +++++++++++++++++ src/zarr/core/array.py | 8 +- src/zarr/core/buffer/core.py | 6 +- src/zarr/core/config.py | 2 + src/zarr/core/metadata/v3.py | 205 ++++++++++++++++++++---------- src/zarr/core/strings.py | 87 +++++++++++++ tests/test_strings.py | 35 +++++ tests/v3/test_array.py | 38 +++++- tests/v3/test_codecs/test_vlen.py | 95 ++++++++++++++ tests/v3/test_config.py | 2 + tests/v3/test_metadata/test_v3.py | 62 +++++---- 12 files changed, 584 insertions(+), 94 deletions(-) create mode 100644 src/zarr/codecs/vlen_utf8.py create mode 100644 src/zarr/core/strings.py create mode 100644 tests/test_strings.py create mode 100644 tests/v3/test_codecs/test_vlen.py diff --git a/src/zarr/codecs/__init__.py b/src/zarr/codecs/__init__.py index 9394284319..ac647d7863 100644 --- a/src/zarr/codecs/__init__.py +++ b/src/zarr/codecs/__init__.py @@ -1,5 +1,10 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + import numpy as np + from zarr.codecs.blosc import BloscCname, BloscCodec, BloscShuffle from zarr.codecs.bytes import BytesCodec, Endian from zarr.codecs.crc32c_ import Crc32cCodec @@ -7,7 +12,9 @@ from zarr.codecs.pipeline import BatchedCodecPipeline from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation from zarr.codecs.transpose import TransposeCodec +from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec from zarr.codecs.zstd import ZstdCodec +from zarr.core.metadata.v3 import DataType __all__ = [ "BatchedCodecPipeline", @@ -21,5 +28,19 @@ "ShardingCodec", "ShardingCodecIndexLocation", "TransposeCodec", + "VLenUTF8Codec", + "VLenBytesCodec", "ZstdCodec", ] + + +def _get_default_array_bytes_codec( + np_dtype: np.dtype[Any], +) -> BytesCodec | VLenUTF8Codec | VLenBytesCodec: + dtype = DataType.from_numpy(np_dtype) + if dtype == DataType.string: + return VLenUTF8Codec() + elif dtype == DataType.bytes: + return VLenBytesCodec() + else: + return BytesCodec() diff --git a/src/zarr/codecs/vlen_utf8.py b/src/zarr/codecs/vlen_utf8.py new file mode 100644 index 0000000000..43544e0809 --- /dev/null +++ b/src/zarr/codecs/vlen_utf8.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np +from numcodecs.vlen import VLenBytes, VLenUTF8 + +from zarr.abc.codec import ArrayBytesCodec +from zarr.core.buffer import Buffer, NDBuffer +from zarr.core.common import JSON, parse_named_configuration +from zarr.core.strings import cast_to_string_dtype +from zarr.registry import register_codec + +if TYPE_CHECKING: + from typing import Self + + from zarr.core.array_spec import ArraySpec + + +# can use a global because there are no parameters +_vlen_utf8_codec = VLenUTF8() +_vlen_bytes_codec = VLenBytes() + + +@dataclass(frozen=True) +class VLenUTF8Codec(ArrayBytesCodec): + @classmethod + def from_dict(cls, data: dict[str, JSON]) -> Self: + _, configuration_parsed = parse_named_configuration( + data, "vlen-utf8", require_configuration=False + ) + configuration_parsed = configuration_parsed or {} + return cls(**configuration_parsed) + + def to_dict(self) -> dict[str, JSON]: + return {"name": "vlen-utf8", "configuration": {}} + + def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: + return self + + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + assert isinstance(chunk_bytes, Buffer) + + raw_bytes = chunk_bytes.as_array_like() + decoded = _vlen_utf8_codec.decode(raw_bytes) + assert decoded.dtype == np.object_ + decoded.shape = chunk_spec.shape + # coming out of the code, we know this is safe, so don't issue a warning + as_string_dtype = cast_to_string_dtype(decoded, safe=True) + return chunk_spec.prototype.nd_buffer.from_numpy_array(as_string_dtype) + + async def _encode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + assert isinstance(chunk_array, NDBuffer) + return chunk_spec.prototype.buffer.from_bytes( + _vlen_utf8_codec.encode(chunk_array.as_numpy_array()) + ) + + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: + # what is input_byte_length for an object dtype? + raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs") + + +@dataclass(frozen=True) +class VLenBytesCodec(ArrayBytesCodec): + @classmethod + def from_dict(cls, data: dict[str, JSON]) -> Self: + _, configuration_parsed = parse_named_configuration( + data, "vlen-bytes", require_configuration=False + ) + configuration_parsed = configuration_parsed or {} + return cls(**configuration_parsed) + + def to_dict(self) -> dict[str, JSON]: + return {"name": "vlen-bytes", "configuration": {}} + + def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: + return self + + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + assert isinstance(chunk_bytes, Buffer) + + raw_bytes = chunk_bytes.as_array_like() + decoded = _vlen_bytes_codec.decode(raw_bytes) + assert decoded.dtype == np.object_ + decoded.shape = chunk_spec.shape + return chunk_spec.prototype.nd_buffer.from_numpy_array(decoded) + + async def _encode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + assert isinstance(chunk_array, NDBuffer) + return chunk_spec.prototype.buffer.from_bytes( + _vlen_bytes_codec.encode(chunk_array.as_numpy_array()) + ) + + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: + # what is input_byte_length for an object dtype? + raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs") + + +register_codec("vlen-utf8", VLenUTF8Codec) +register_codec("vlen-bytes", VLenBytesCodec) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 9a78297c6f..9f5591ce1e 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -11,7 +11,7 @@ from zarr._compat import _deprecate_positional_args from zarr.abc.store import Store, set_or_delete -from zarr.codecs import BytesCodec +from zarr.codecs import _get_default_array_bytes_codec from zarr.codecs._v2 import V2Compressor, V2Filters from zarr.core.attributes import Attributes from zarr.core.buffer import ( @@ -318,7 +318,11 @@ async def _create_v3( await ensure_no_existing_node(store_path, zarr_format=3) shape = parse_shapelike(shape) - codecs = list(codecs) if codecs is not None else [BytesCodec()] + codecs = ( + list(codecs) + if codecs is not None + else [_get_default_array_bytes_codec(np.dtype(dtype))] + ) if chunk_key_encoding is None: chunk_key_encoding = ("default", "/") diff --git a/src/zarr/core/buffer/core.py b/src/zarr/core/buffer/core.py index f2f81d9c51..1fbf58c618 100644 --- a/src/zarr/core/buffer/core.py +++ b/src/zarr/core/buffer/core.py @@ -313,8 +313,6 @@ class NDBuffer: """ def __init__(self, array: NDArrayLike) -> None: - # assert array.ndim > 0 - assert array.dtype != object self._data = array @classmethod @@ -467,9 +465,11 @@ def all_equal(self, other: Any, equal_nan: bool = True) -> bool: # Handle None fill_value for Zarr V2 return False # use array_equal to obtain equal_nan=True functionality + # Since fill-value is a scalar, isn't there a faster path than allocating a new array for fill value + # every single time we have to write data? _data, other = np.broadcast_arrays(self._data, other) return np.array_equal( - self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "US" else False + self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "USTO" else False ) def fill(self, value: Any) -> None: diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index 735755616f..3fe7d803d2 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -58,6 +58,8 @@ def reset(self) -> None: "crc32c": "zarr.codecs.crc32c_.Crc32cCodec", "sharding_indexed": "zarr.codecs.sharding.ShardingCodec", "transpose": "zarr.codecs.transpose.TransposeCodec", + "vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec", + "vlen-bytes": "zarr.codecs.vlen_utf8.VLenBytesCodec", }, "buffer": "zarr.core.buffer.cpu.Buffer", "ndbuffer": "zarr.core.buffer.cpu.NDBuffer", diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 692f778566..47c6106bfe 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, cast, overload +from typing import TYPE_CHECKING, overload if TYPE_CHECKING: from typing import Self @@ -14,7 +14,7 @@ from collections.abc import Iterable, Sequence from dataclasses import dataclass, field, replace from enum import Enum -from typing import Any, Literal +from typing import Any, Literal, cast import numcodecs.abc import numpy as np @@ -28,6 +28,7 @@ from zarr.core.common import ZARR_JSON, parse_named_configuration, parse_shapelike from zarr.core.config import config from zarr.core.metadata.common import ArrayMetadata, parse_attributes +from zarr.core.strings import _STRING_DTYPE as STRING_NP_DTYPE from zarr.registry import get_codec_class DEFAULT_DTYPE = "float64" @@ -63,6 +64,34 @@ def parse_codecs(data: object) -> tuple[Codec, ...]: return out +def validate_codecs(codecs: tuple[Codec, ...], dtype: DataType) -> None: + """Check that the codecs are valid for the given dtype""" + + # ensure that we have at least one ArrayBytesCodec + abcs: list[ArrayBytesCodec] = [] + for codec in codecs: + if isinstance(codec, ArrayBytesCodec): + abcs.append(codec) + if len(abcs) == 0: + raise ValueError("At least one ArrayBytesCodec is required.") + elif len(abcs) > 1: + raise ValueError("Only one ArrayBytesCodec is allowed.") + + abc = abcs[0] + + # we need to have special codecs if we are decoding vlen strings or bytestrings + # TODO: use codec ID instead of class name + codec_id = abc.__class__.__name__ + if dtype == DataType.string and not codec_id == "VLenUTF8Codec": + raise ValueError( + f"For string dtype, ArrayBytesCodec must be `VLenUTF8Codec`, got `{codec_id}`." + ) + if dtype == DataType.bytes and not codec_id == "VLenBytesCodec": + raise ValueError( + f"For bytes dtype, ArrayBytesCodec must be `VLenBytesCodec`, got `{codec_id}`." + ) + + def parse_dimension_names(data: object) -> tuple[str | None, ...] | None: if data is None: return data @@ -185,7 +214,12 @@ def __init__( chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid) chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding) dimension_names_parsed = parse_dimension_names(dimension_names) - fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed.to_numpy()) + if fill_value is None: + fill_value = default_fill_value(data_type_parsed) + # we pass a string here rather than an enum to make mypy happy + fill_value_parsed = parse_fill_value( + fill_value, dtype=cast(ALL_DTYPES, data_type_parsed.value) + ) attributes_parsed = parse_attributes(attributes) codecs_parsed_partial = parse_codecs(codecs) storage_transformers_parsed = parse_storage_transformers(storage_transformers) @@ -198,6 +232,7 @@ def __init__( prototype=default_buffer_prototype(), # TODO: prototype is not needed here. ) codecs_parsed = [c.evolve_from_array_spec(array_spec) for c in codecs_parsed_partial] + validate_codecs(codecs_parsed_partial, data_type_parsed) object.__setattr__(self, "shape", shape_parsed) object.__setattr__(self, "data_type", data_type_parsed) @@ -297,72 +332,71 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self: return replace(self, attributes=attributes) +# enum Literals can't be used in typing, so we have to restate all of the V3 dtypes as types +# https://github.com/python/typing/issues/781 + +BOOL_DTYPE = Literal["bool"] BOOL = np.bool_ -BOOL_DTYPE = np.dtypes.BoolDType -INTEGER_DTYPE = ( - np.dtypes.Int8DType - | np.dtypes.Int16DType - | np.dtypes.Int32DType - | np.dtypes.Int64DType - | np.dtypes.UInt8DType - | np.dtypes.UInt16DType - | np.dtypes.UInt32DType - | np.dtypes.UInt64DType -) +INTEGER_DTYPE = Literal["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"] INTEGER = np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64 -FLOAT_DTYPE = np.dtypes.Float16DType | np.dtypes.Float32DType | np.dtypes.Float64DType +FLOAT_DTYPE = Literal["float16", "float32", "float64"] FLOAT = np.float16 | np.float32 | np.float64 -COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType +COMPLEX_DTYPE = Literal["complex64", "complex128"] COMPLEX = np.complex64 | np.complex128 +STRING_DTYPE = Literal["string"] +STRING = np.str_ +BYTES_DTYPE = Literal["bytes"] +BYTES = np.bytes_ + +ALL_DTYPES = BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | STRING_DTYPE | BYTES_DTYPE @overload def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None, + fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, dtype: BOOL_DTYPE, ) -> BOOL: ... @overload def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None, + fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, dtype: INTEGER_DTYPE, ) -> INTEGER: ... @overload def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None, + fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, dtype: FLOAT_DTYPE, ) -> FLOAT: ... @overload def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None, + fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, dtype: COMPLEX_DTYPE, ) -> COMPLEX: ... @overload def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None, - dtype: np.dtype[Any], -) -> Any: - # This dtype[Any] is unfortunately necessary right now. - # See https://github.com/zarr-developers/zarr-python/issues/2131#issuecomment-2318010899 - # for more details, but `dtype` here (which comes from `parse_dtype`) - # is np.dtype[Any]. - # - # If you want the specialized types rather than Any, you need to use `np.dtypes.` - # rather than np.dtypes() - ... + fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, + dtype: STRING_DTYPE, +) -> STRING: ... +@overload def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None, - dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | np.dtype[Any], -) -> BOOL | INTEGER | FLOAT | COMPLEX | Any: + fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, + dtype: BYTES_DTYPE, +) -> BYTES: ... + + +def parse_fill_value( + fill_value: Any, + dtype: ALL_DTYPES, +) -> Any: """ Parse `fill_value`, a potential fill value, into an instance of `dtype`, a data type. If `fill_value` is `None`, then this function will return the result of casting the value 0 @@ -376,29 +410,37 @@ def parse_fill_value( ---------- fill_value: Any A potential fill value. - dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE - A numpy data type that models a data type defined in the Zarr V3 specification. + dtype: str + A valid Zarr V3 DataType. Returns ------- A scalar instance of `dtype` """ + data_type = DataType(dtype) if fill_value is None: - return dtype.type(0) + raise ValueError("Fill value cannot be None") + if data_type == DataType.string: + return np.str_(fill_value) + if data_type == DataType.bytes: + return np.bytes_(fill_value) + + # the rest are numeric types + np_dtype = cast(np.dtype[np.generic], data_type.to_numpy()) + if isinstance(fill_value, Sequence) and not isinstance(fill_value, str): - if dtype.type in (np.complex64, np.complex128): - dtype = cast(COMPLEX_DTYPE, dtype) + if data_type in (DataType.complex64, DataType.complex128): if len(fill_value) == 2: # complex datatypes serialize to JSON arrays with two elements - return dtype.type(complex(*fill_value)) + return np_dtype.type(complex(*fill_value)) else: msg = ( - f"Got an invalid fill value for complex data type {dtype}." + f"Got an invalid fill value for complex data type {data_type.value}." f"Expected a sequence with 2 elements, but {fill_value!r} has " f"length {len(fill_value)}." ) raise ValueError(msg) - msg = f"Cannot parse non-string sequence {fill_value!r} as a scalar with type {dtype}." + msg = f"Cannot parse non-string sequence {fill_value!r} as a scalar with type {data_type.value}." raise TypeError(msg) # Cast the fill_value to the given dtype @@ -409,27 +451,38 @@ def parse_fill_value( # fill_value != casted_value below. with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) - casted_value = np.dtype(dtype).type(fill_value) + casted_value = np.dtype(np_dtype).type(fill_value) except (ValueError, OverflowError, TypeError) as e: - raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") from e + raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}") from e # Check if the value is still representable by the dtype if fill_value == "NaN" and np.isnan(casted_value): pass elif fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value): pass - elif dtype.kind in "cf": + elif np_dtype.kind in "cf": # float comparison is not exact, especially when dtype str | bytes | np.generic: + if dtype == DataType.string: + return "" + elif dtype == DataType.bytes: + return b"" + else: + np_dtype = dtype.to_numpy() + np_dtype = cast(np.dtype[np.generic], np_dtype) + return np_dtype.type(0) + + # For type checking _bool = bool @@ -449,9 +502,11 @@ class DataType(Enum): float64 = "float64" complex64 = "complex64" complex128 = "complex128" + string = "string" + bytes = "bytes" @property - def byte_count(self) -> int: + def byte_count(self) -> None | int: data_type_byte_counts = { DataType.bool: 1, DataType.int8: 1, @@ -468,12 +523,15 @@ def byte_count(self) -> int: DataType.complex64: 8, DataType.complex128: 16, } - return data_type_byte_counts[self] + try: + return data_type_byte_counts[self] + except KeyError: + # string and bytes have variable length + return None @property def has_endianness(self) -> _bool: - # This might change in the future, e.g. for a complex with 2 8-bit floats - return self.byte_count != 1 + return self.byte_count is not None and self.byte_count != 1 def to_numpy_shortname(self) -> str: data_type_to_numpy = { @@ -494,11 +552,26 @@ def to_numpy_shortname(self) -> str: } return data_type_to_numpy[self] - def to_numpy(self) -> np.dtype[Any]: - return np.dtype(self.to_numpy_shortname()) + def to_numpy(self) -> np.dtypes.StringDType | np.dtypes.ObjectDType | np.dtype[np.generic]: + # note: it is not possible to round trip DataType <-> np.dtype + # due to the fact that DataType.string and DataType.bytes both + # generally return np.dtype("O") from this function, even though + # they can originate as fixed-length types (e.g. " DataType: + if dtype.kind in "UT": + return DataType.string + elif dtype.kind == "S": + return DataType.bytes dtype_to_data_type = { "|b1": "bool", "bool": "bool", @@ -521,18 +594,20 @@ def from_numpy(cls, dtype: np.dtype[Any]) -> DataType: @classmethod def parse(cls, dtype: None | DataType | Any) -> DataType: if dtype is None: - # the default dtype return DataType[DEFAULT_DTYPE] if isinstance(dtype, DataType): return dtype - else: - try: - dtype = np.dtype(dtype) - except (ValueError, TypeError) as e: - raise ValueError(f"Invalid V3 data_type: {dtype}") from e - # check that this is a valid v3 data_type - try: - data_type = DataType.from_numpy(dtype) - except KeyError as e: - raise ValueError(f"Invalid V3 data_type: {dtype}") from e - return data_type + try: + return DataType(dtype) + except ValueError: + pass + try: + dtype = np.dtype(dtype) + except (ValueError, TypeError) as e: + raise ValueError(f"Invalid V3 data_type: {dtype}") from e + # check that this is a valid v3 data_type + try: + data_type = DataType.from_numpy(dtype) + except KeyError as e: + raise ValueError(f"Invalid V3 data_type: {dtype}") from e + return data_type diff --git a/src/zarr/core/strings.py b/src/zarr/core/strings.py new file mode 100644 index 0000000000..9ec391c04a --- /dev/null +++ b/src/zarr/core/strings.py @@ -0,0 +1,87 @@ +"""This module contains utilities for working with string arrays across +different versions of Numpy. +""" + +from typing import Any, Union, cast +from warnings import warn + +import numpy as np + +# _STRING_DTYPE is the in-memory datatype that will be used for V3 string arrays +# when reading data back from Zarr. +# Any valid string-like datatype should be fine for *setting* data. + +_STRING_DTYPE: Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"] +_NUMPY_SUPPORTS_VLEN_STRING: bool + + +def cast_array( + data: np.ndarray[Any, np.dtype[Any]], +) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]: + raise NotImplementedError + + +try: + # this new vlen string dtype was added in NumPy 2.0 + _STRING_DTYPE = np.dtypes.StringDType() + _NUMPY_SUPPORTS_VLEN_STRING = True + + def cast_array( + data: np.ndarray[Any, np.dtype[Any]], + ) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]: + out = data.astype(_STRING_DTYPE, copy=False) + return cast(np.ndarray[Any, np.dtypes.StringDType], out) + +except AttributeError: + # if not available, we fall back on an object array of strings, as in Zarr < 3 + _STRING_DTYPE = np.dtypes.ObjectDType() + _NUMPY_SUPPORTS_VLEN_STRING = False + + def cast_array( + data: np.ndarray[Any, np.dtype[Any]], + ) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]: + out = data.astype(_STRING_DTYPE, copy=False) + return cast(np.ndarray[Any, np.dtypes.ObjectDType], out) + + +def cast_to_string_dtype( + data: np.ndarray[Any, np.dtype[Any]], safe: bool = False +) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]: + """Take any data and attempt to cast to to our preferred string dtype. + + data : np.ndarray + The data to cast + + safe : bool + If True, do not issue a warning if the data is cast from object to string dtype. + + """ + if np.issubdtype(data.dtype, np.str_): + # legacy fixed-width string type (e.g. "= 2.", + stacklevel=2, + ) + return cast_array(data) + raise ValueError(f"Cannot cast dtype {data.dtype} to string dtype") diff --git a/tests/test_strings.py b/tests/test_strings.py new file mode 100644 index 0000000000..dca0570a25 --- /dev/null +++ b/tests/test_strings.py @@ -0,0 +1,35 @@ +"""Tests for the strings module.""" + +import numpy as np +import pytest + +from zarr.core.strings import _NUMPY_SUPPORTS_VLEN_STRING, _STRING_DTYPE, cast_to_string_dtype + + +def test_string_defaults() -> None: + if _NUMPY_SUPPORTS_VLEN_STRING: + assert _STRING_DTYPE == np.dtypes.StringDType() + else: + assert _STRING_DTYPE == np.dtypes.ObjectDType() + + +def test_cast_to_string_dtype() -> None: + d1 = np.array(["a", "b", "c"]) + assert d1.dtype == np.dtype(" None: expected = sorted(keys) assert observed == expected + + +def test_default_fill_values() -> None: + a = Array.create(MemoryStore({}, mode="w"), shape=5, chunk_shape=5, dtype=" None: + with pytest.raises(ValueError, match="At least one ArrayBytesCodec is required."): + Array.create(MemoryStore({}, mode="w"), shape=5, chunk_shape=5, dtype=" None: + strings = ["hello", "world", "this", "is", "a", "test"] + data = np.array(strings, dtype=dtype).reshape((2, 3)) + + sp = StorePath(store, path="string") + a = Array.create( + sp, + shape=data.shape, + chunk_shape=data.shape, + dtype=data.dtype, + fill_value="", + codecs=codecs, + ) + assert isinstance(a.metadata, ArrayV3Metadata) # needed for mypy + + # should also work if input array is an object array, provided we explicitly specified + # a stringlike dtype when creating the Array + if as_object_array: + data = data.astype("O") + + a[:, :] = data + assert np.array_equal(data, a[:, :]) + assert a.metadata.data_type == DataType.string + assert a.dtype == expected_zarr_string_dtype + + # test round trip + b = Array.open(sp) + assert isinstance(b.metadata, ArrayV3Metadata) # needed for mypy + assert np.array_equal(data, b[:, :]) + assert b.metadata.data_type == DataType.string + assert a.dtype == expected_zarr_string_dtype + + +@pytest.mark.parametrize("store", ["memory", "local"], indirect=["store"]) +@pytest.mark.parametrize("as_object_array", [False, True]) +@pytest.mark.parametrize("codecs", [None, [VLenBytesCodec()], [VLenBytesCodec(), ZstdCodec()]]) +def test_vlen_bytes(store: Store, as_object_array: bool, codecs: None | list[Codec]) -> None: + bstrings = [b"hello", b"world", b"this", b"is", b"a", b"test"] + data = np.array(bstrings).reshape((2, 3)) + assert data.dtype == "|S5" + + sp = StorePath(store, path="string") + a = Array.create( + sp, + shape=data.shape, + chunk_shape=data.shape, + dtype=data.dtype, + fill_value=b"", + codecs=codecs, + ) + assert isinstance(a.metadata, ArrayV3Metadata) # needed for mypy + + # should also work if input array is an object array, provided we explicitly specified + # a bytesting-like dtype when creating the Array + if as_object_array: + data = data.astype("O") + a[:, :] = data + assert np.array_equal(data, a[:, :]) + assert a.metadata.data_type == DataType.bytes + assert a.dtype == "O" + + # test round trip + b = Array.open(sp) + assert isinstance(b.metadata, ArrayV3Metadata) # needed for mypy + assert np.array_equal(data, b[:, :]) + assert b.metadata.data_type == DataType.bytes + assert a.dtype == "O" diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index e324367b3d..2adc51aa57 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -58,6 +58,8 @@ def test_config_defaults_set() -> None: "crc32c": "zarr.codecs.crc32c_.Crc32cCodec", "sharding_indexed": "zarr.codecs.sharding.ShardingCodec", "transpose": "zarr.codecs.transpose.TransposeCodec", + "vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec", + "vlen-bytes": "zarr.codecs.vlen_utf8.VLenBytesCodec", }, } ] diff --git a/tests/v3/test_metadata/test_v3.py b/tests/v3/test_metadata/test_v3.py index 534ef61d09..b55b2d7a65 100644 --- a/tests/v3/test_metadata/test_v3.py +++ b/tests/v3/test_metadata/test_v3.py @@ -21,6 +21,7 @@ import pytest from zarr.core.metadata.v3 import ( + default_fill_value, parse_dimension_names, parse_fill_value, parse_zarr_format, @@ -46,8 +47,9 @@ ) complex_dtypes = ("complex64", "complex128") +vlen_dtypes = ("string", "bytes") -dtypes = (*bool_dtypes, *int_dtypes, *float_dtypes, *complex_dtypes) +dtypes = (*bool_dtypes, *int_dtypes, *float_dtypes, *complex_dtypes, *vlen_dtypes) @pytest.mark.parametrize("data", [None, 1, 2, 4, 5, "3"]) @@ -72,13 +74,18 @@ def parse_dimension_names_valid(data: Sequence[str] | None) -> None: @pytest.mark.parametrize("dtype_str", dtypes) -def test_parse_auto_fill_value(dtype_str: str) -> None: +def test_default_fill_value(dtype_str: str) -> None: """ Test that parse_fill_value(None, dtype) results in the 0 value for the given dtype. """ - dtype = np.dtype(dtype_str) - fill_value = None - assert parse_fill_value(fill_value, dtype) == dtype.type(0) + dtype = DataType(dtype_str) + fill_value = default_fill_value(dtype) + if dtype == DataType.string: + assert fill_value == "" + elif dtype == DataType.bytes: + assert fill_value == b"" + else: + assert fill_value == dtype.to_numpy().type(0) @pytest.mark.parametrize( @@ -102,13 +109,12 @@ def test_parse_fill_value_valid(fill_value: Any, dtype_str: str) -> None: """ Test that parse_fill_value(fill_value, dtype) casts fill_value to the given dtype. """ - dtype = np.dtype(dtype_str) - parsed = parse_fill_value(fill_value, dtype) + parsed = parse_fill_value(fill_value, dtype_str) if np.isnan(fill_value): assert np.isnan(parsed) else: - assert parsed == dtype.type(fill_value) + assert parsed == DataType(dtype_str).to_numpy().type(fill_value) @pytest.mark.parametrize("fill_value", ["not a valid value"]) @@ -118,9 +124,8 @@ def test_parse_fill_value_invalid_value(fill_value: Any, dtype_str: str) -> None Test that parse_fill_value(fill_value, dtype) raises ValueError for invalid values. This test excludes bool because the bool constructor takes anything. """ - dtype = np.dtype(dtype_str) with pytest.raises(ValueError): - parse_fill_value(fill_value, dtype) + parse_fill_value(fill_value, dtype_str) @pytest.mark.parametrize("fill_value", [[1.0, 0.0], [0, 1], complex(1, 1), np.complex64(0)]) @@ -130,12 +135,12 @@ def test_parse_fill_value_complex(fill_value: Any, dtype_str: str) -> None: Test that parse_fill_value(fill_value, dtype) correctly handles complex values represented as length-2 sequences """ - dtype = np.dtype(dtype_str) + dtype = DataType(dtype_str) if isinstance(fill_value, list): - expected = dtype.type(complex(*fill_value)) + expected = dtype.to_numpy().type(complex(*fill_value)) else: - expected = dtype.type(fill_value) - assert expected == parse_fill_value(fill_value, dtype) + expected = dtype.to_numpy().type(fill_value) + assert expected == parse_fill_value(fill_value, dtype_str) @pytest.mark.parametrize("fill_value", [[1.0, 0.0, 3.0], [0, 1, 3], [1]]) @@ -145,14 +150,13 @@ def test_parse_fill_value_complex_invalid(fill_value: Any, dtype_str: str) -> No Test that parse_fill_value(fill_value, dtype) correctly rejects sequences with length not equal to 2 """ - dtype = np.dtype(dtype_str) match = ( - f"Got an invalid fill value for complex data type {dtype}." + f"Got an invalid fill value for complex data type {dtype_str}." f"Expected a sequence with 2 elements, but {fill_value} has " f"length {len(fill_value)}." ) with pytest.raises(ValueError, match=re.escape(match)): - parse_fill_value(fill_value=fill_value, dtype=dtype) + parse_fill_value(fill_value=fill_value, dtype=dtype_str) @pytest.mark.parametrize("fill_value", [{"foo": 10}]) @@ -162,9 +166,8 @@ def test_parse_fill_value_invalid_type(fill_value: Any, dtype_str: str) -> None: Test that parse_fill_value(fill_value, dtype) raises TypeError for invalid non-sequential types. This test excludes bool because the bool constructor takes anything. """ - dtype = np.dtype(dtype_str) with pytest.raises(ValueError, match=r"fill value .* is not valid for dtype .*"): - parse_fill_value(fill_value, dtype) + parse_fill_value(fill_value, dtype_str) @pytest.mark.parametrize( @@ -183,10 +186,9 @@ def test_parse_fill_value_invalid_type_sequence(fill_value: Any, dtype_str: str) This test excludes bool because the bool constructor takes anything, and complex because complex values can be created from length-2 sequences. """ - dtype = np.dtype(dtype_str) - match = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype}" + match = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype_str}" with pytest.raises(TypeError, match=re.escape(match)): - parse_fill_value(fill_value, dtype) + parse_fill_value(fill_value, dtype_str) @pytest.mark.parametrize("chunk_grid", ["regular"]) @@ -337,7 +339,7 @@ async def test_special_float_fill_values(fill_value: str) -> None: "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, "data_type": "float64", "chunk_key_encoding": {"name": "default", "separator": "."}, - "codecs": (), + "codecs": [{"name": "bytes"}], "fill_value": fill_value, # this is not a valid fill value for uint8 } m = ArrayV3Metadata.from_dict(metadata_dict) @@ -352,3 +354,17 @@ async def test_special_float_fill_values(fill_value: str) -> None: elif fill_value == "-Infinity": assert np.isneginf(m.fill_value) assert d["fill_value"] == "-Infinity" + + +@pytest.mark.parametrize("dtype_str", dtypes) +def test_dtypes(dtype_str: str) -> None: + dt = DataType(dtype_str) + np_dtype = dt.to_numpy() + if dtype_str not in vlen_dtypes: + # we can round trip "normal" dtypes + assert dt == DataType.from_numpy(np_dtype) + assert dt.byte_count == np_dtype.itemsize + assert dt.has_endianness == (dt.byte_count > 1) + else: + # return type for vlen types may vary depending on numpy version + assert dt.byte_count is None