From f1b01ac7cb1d673314320c9427049ae62046dcf3 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 18 Sep 2024 13:51:44 -0700 Subject: [PATCH 1/7] fix: validate v3 dtypes when loading/creating v3 metadata --- src/zarr/core/array_spec.py | 4 +-- src/zarr/core/common.py | 7 ----- src/zarr/core/metadata/v2.py | 7 ++++- src/zarr/core/metadata/v3.py | 32 +++++++++++++++++++- src/zarr/testing/strategies.py | 4 ++- tests/v3/test_metadata/test_v3.py | 50 +++++++++++++++++++++++-------- 6 files changed, 79 insertions(+), 25 deletions(-) diff --git a/src/zarr/core/array_spec.py b/src/zarr/core/array_spec.py index e64a962bc3..82deaf6477 100644 --- a/src/zarr/core/array_spec.py +++ b/src/zarr/core/array_spec.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal -from zarr.core.common import parse_dtype, parse_fill_value, parse_order, parse_shapelike +from zarr.core.common import parse_fill_value, parse_order, parse_shapelike if TYPE_CHECKING: import numpy as np @@ -29,7 +29,7 @@ def __init__( prototype: BufferPrototype, ) -> None: shape_parsed = parse_shapelike(shape) - dtype_parsed = parse_dtype(dtype) + dtype_parsed = dtype # parsing is likely not needed here fill_value_parsed = parse_fill_value(fill_value) order_parsed = parse_order(order) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 906467005f..745b95beb4 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -19,8 +19,6 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterator -import numpy as np -import numpy.typing as npt ZARR_JSON = "zarr.json" ZARRAY_JSON = ".zarray" @@ -154,11 +152,6 @@ def parse_shapelike(data: int | Iterable[int]) -> tuple[int, ...]: return data_tuple -def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]: - # todo: real validation - return np.dtype(data) - - def parse_fill_value(data: Any) -> Any: # todo: real validation return data diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index af7821bea7..ee990800b3 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -21,7 +21,7 @@ from zarr.core.array_spec import ArraySpec from zarr.core.chunk_grids import RegularChunkGrid from zarr.core.chunk_key_encodings import parse_separator -from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_dtype, parse_shapelike +from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_shapelike from zarr.core.config import config, parse_indexing_order from zarr.core.metadata.common import ArrayMetadata, parse_attributes @@ -157,6 +157,11 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self: return replace(self, attributes=attributes) +def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]: + # todo: real validation + return np.dtype(data) + + def parse_zarr_format(data: object) -> Literal[2]: if data == 2: return 2 diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 10047cbb93..6d72531967 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -24,7 +24,7 @@ from zarr.core.buffer import default_buffer_prototype from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid from zarr.core.chunk_key_encodings import ChunkKeyEncoding -from zarr.core.common import ZARR_JSON, parse_dtype, parse_named_configuration, parse_shapelike +from zarr.core.common import ZARR_JSON, parse_named_configuration, parse_shapelike from zarr.core.config import config from zarr.core.metadata.common import ArrayMetadata, parse_attributes from zarr.registry import get_codec_class @@ -215,6 +215,10 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: # check that the node_type attribute is correct _ = parse_node_type_array(_data.pop("node_type")) + # check that the data_type attribute is valid + if _data["data_type"] not in DataType: + raise ValueError(f"Invalid V3 data_type: {_data['data_type']}") + # dimension_names key is optional, normalize missing to `None` _data["dimension_names"] = _data.pop("dimension_names", None) # attributes key is optional, normalize missing to `None` @@ -345,8 +349,11 @@ class DataType(Enum): uint16 = "uint16" uint32 = "uint32" uint64 = "uint64" + float16 = "float16" float32 = "float32" float64 = "float64" + complex64 = "complex64" + complex128 = "complex128" @property def byte_count(self) -> int: @@ -360,8 +367,11 @@ def byte_count(self) -> int: DataType.uint16: 2, DataType.uint32: 4, DataType.uint64: 8, + DataType.float16: 2, DataType.float32: 4, DataType.float64: 8, + DataType.complex64: 8, + DataType.complex128: 16, } return data_type_byte_counts[self] @@ -381,8 +391,11 @@ def to_numpy_shortname(self) -> str: DataType.uint16: "u2", DataType.uint32: "u4", DataType.uint64: "u8", + DataType.float16: "f2", DataType.float32: "f4", DataType.float64: "f8", + DataType.complex64: "c8", + DataType.complex128: "c16", } return data_type_to_numpy[self] @@ -399,7 +412,24 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType: " np.dtype[Any]: + try: + dtype = np.dtype(data) + except TypeError as e: + raise ValueError(f"Invalid V3 data_type: {data}") from e + # check that this is a valid v3 data_type + try: + _ = DataType.from_dtype(dtype) + except KeyError as e: + raise ValueError(f"Invalid V3 data_type: {dtype}") from e + + return dtype diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index 83de3d92ce..d59003c7ce 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -35,7 +35,9 @@ paths = st.lists(node_names, min_size=1).map(lambda x: "/".join(x)) | st.just("/") np_arrays = npst.arrays( # TODO: re-enable timedeltas once they are supported - dtype=npst.scalar_dtypes().filter(lambda x: x.kind != "m"), + dtype=npst.scalar_dtypes().filter( + lambda x: (x.kind not in ["m", "M"]) and (x.byteorder not in [">"]) + ), shape=npst.array_shapes(max_dims=4), ) stores = st.builds(MemoryStore, st.just({}), mode=st.just("w")) diff --git a/tests/v3/test_metadata/test_v3.py b/tests/v3/test_metadata/test_v3.py index 0a545dfb9d..a8f61f90ee 100644 --- a/tests/v3/test_metadata/test_v3.py +++ b/tests/v3/test_metadata/test_v3.py @@ -1,11 +1,9 @@ from __future__ import annotations -import json import re from typing import TYPE_CHECKING, Literal from zarr.codecs.bytes import BytesCodec -from zarr.core.buffer import default_buffer_prototype from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding from zarr.core.metadata.v3 import ArrayV3Metadata @@ -19,7 +17,12 @@ import numpy as np import pytest -from zarr.core.metadata.v3 import parse_dimension_names, parse_fill_value, parse_zarr_format +from zarr.core.metadata.v3 import ( + parse_dimension_names, + parse_dtype, + parse_fill_value, + parse_zarr_format, +) bool_dtypes = ("bool",) @@ -234,22 +237,43 @@ def test_metadata_to_dict( assert observed == expected -@pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897]) -@pytest.mark.parametrize("precision", ["ns", "D"]) -async def test_datetime_metadata(fill_value: int, precision: str) -> None: +# @pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897]) +# @pytest.mark.parametrize("precision", ["ns", "D"]) +# async def test_datetime_metadata(fill_value: int, precision: str) -> None: +# metadata_dict = { +# "zarr_format": 3, +# "node_type": "array", +# "shape": (1,), +# "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, +# "data_type": f" None: metadata_dict = { "zarr_format": 3, "node_type": "array", "shape": (1,), "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, - "data_type": f" Date: Thu, 19 Sep 2024 19:07:39 -0700 Subject: [PATCH 2/7] tests passing --- pyproject.toml | 1 + src/zarr/core/array_spec.py | 3 +-- src/zarr/core/metadata/v2.py | 22 +++++++++++++++++++--- src/zarr/testing/strategies.py | 5 ++++- 4 files changed, 25 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 886cd5a0bc..effcac83ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -273,6 +273,7 @@ filterwarnings = [ "ignore:PY_SSIZE_T_CLEAN will be required.*:DeprecationWarning", "ignore:The loop argument is deprecated since Python 3.8.*:DeprecationWarning", "ignore:Creating a zarr.buffer.gpu.*:UserWarning", + "ignore:Duplicate name:UserWarning", # from ZipFile ] markers = [ "gpu: mark a test as requiring CuPy and GPU" diff --git a/src/zarr/core/array_spec.py b/src/zarr/core/array_spec.py index 82deaf6477..1a251a0a4b 100644 --- a/src/zarr/core/array_spec.py +++ b/src/zarr/core/array_spec.py @@ -29,12 +29,11 @@ def __init__( prototype: BufferPrototype, ) -> None: shape_parsed = parse_shapelike(shape) - dtype_parsed = dtype # parsing is likely not needed here fill_value_parsed = parse_fill_value(fill_value) order_parsed = parse_order(order) object.__setattr__(self, "shape", shape_parsed) - object.__setattr__(self, "dtype", dtype_parsed) + object.__setattr__(self, "dtype", dtype) object.__setattr__(self, "fill_value", fill_value_parsed) object.__setattr__(self, "order", order_parsed) object.__setattr__(self, "prototype", prototype) diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index ee990800b3..34bdbb537f 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Iterable +from enum import Enum from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -100,9 +101,24 @@ def _json_convert( else: return o.descr if np.isscalar(o): - # convert numpy scalar to python type, and pass - # python types through - return getattr(o, "item", lambda: o)() + out: Any + if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"): + # https://github.com/zarr-developers/zarr-python/issues/2119 + # `.item()` on a datetime type might or might not return an + # integer, depending on the value. + # Explicitly cast to an int first, and then grab .item() + out = o.view("i8").item() + else: + # convert numpy scalar to python type, and pass + # python types through + out = getattr(o, "item", lambda: o)() + if isinstance(out, complex): + # python complex types are not JSON serializable, so we use the + # serialization defined in the zarr v3 spec + return [out.real, out.imag] + return out + if isinstance(o, Enum): + return o.name raise TypeError zarray_dict = self.to_dict() diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index d59003c7ce..00f93026ab 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -42,6 +42,7 @@ ) stores = st.builds(MemoryStore, st.just({}), mode=st.just("w")) compressors = st.sampled_from([None, "default"]) +format = st.sampled_from([2, 3]) @st.composite # type: ignore[misc] @@ -71,12 +72,14 @@ def arrays( paths: st.SearchStrategy[None | str] = paths, array_names: st.SearchStrategy = array_names, attrs: st.SearchStrategy = attrs, + format: st.SearchStrategy = format, ) -> Array: store = draw(stores) nparray, chunks = draw(np_array_and_chunks(arrays=arrays)) path = draw(paths) name = draw(array_names) attributes = draw(attrs) + zarr_format = draw(format) # compressor = draw(compressors) # TODO: clean this up @@ -101,7 +104,7 @@ def arrays( expected_attrs = {} if attributes is None else attributes array_path = path + ("/" if not path.endswith("/") else "") + name - root = Group.create(store) + root = Group.create(store, zarr_format=zarr_format) fill_value_args: tuple[Any, ...] = tuple() if nparray.dtype.kind == "M": m = re.search(r"\[(.+)\]", nparray.dtype.str) From 44d310ba858090ecea794a834369a191a1add9b9 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Thu, 19 Sep 2024 19:27:27 -0700 Subject: [PATCH 3/7] check that fill value is valid for dtype --- src/zarr/core/metadata/v3.py | 12 +++++++++++- tests/v3/test_metadata/test_v3.py | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 6d72531967..34a33103c6 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -332,7 +332,17 @@ def parse_fill_value( raise ValueError(msg) msg = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype}." raise TypeError(msg) - return dtype.type(fill_value) # type: ignore[arg-type] + + # Cast the fill_value to the given dtype + try: + casted_value = np.dtype(dtype).type(fill_value) + except (ValueError, OverflowError, TypeError) as e: + raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") from e + # Check if the value is still representable by the dtype + if fill_value != casted_value: + raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") + + return casted_value # For type checking diff --git a/tests/v3/test_metadata/test_v3.py b/tests/v3/test_metadata/test_v3.py index a8f61f90ee..7bfa2e0192 100644 --- a/tests/v3/test_metadata/test_v3.py +++ b/tests/v3/test_metadata/test_v3.py @@ -277,3 +277,21 @@ async def test_invalid_dtype_raises() -> None: def test_parse_invalid_dtype_raises(data): with pytest.raises(ValueError, match=r"Invalid V3 data_type"): parse_dtype(data) + + +@pytest.mark.parametrize( + "data_type,fill_value", [("uint8", -1), ("int32", 22.5), ("float32", "foo")] +) +async def test_invalid_fill_value_raises(data_type: str, fill_value: int | float) -> None: + metadata_dict = { + "zarr_format": 3, + "node_type": "array", + "shape": (1,), + "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, + "data_type": data_type, + "chunk_key_encoding": {"name": "default", "separator": "."}, + "codecs": (), + "fill_value": fill_value, # this is not a valid fill value for uint8 + } + with pytest.raises(ValueError, match=rf"fill value .* is not valid for dtype {data_type}"): + ArrayV3Metadata.from_dict(metadata_dict) From 236487d98bc7fb03bdad513eb8d23e838a34af7e Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Fri, 20 Sep 2024 16:55:03 -0700 Subject: [PATCH 4/7] fixup --- src/zarr/core/metadata/v3.py | 38 +++++++++++++++++++++---------- tests/v3/test_array.py | 6 +++-- tests/v3/test_metadata/test_v3.py | 24 +++++++++++++------ 3 files changed, 47 insertions(+), 21 deletions(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 34a33103c6..da5f87cf08 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -216,8 +216,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: _ = parse_node_type_array(_data.pop("node_type")) # check that the data_type attribute is valid - if _data["data_type"] not in DataType: - raise ValueError(f"Invalid V3 data_type: {_data['data_type']}") + _ = DataType(_data["data_type"]) # dimension_names key is optional, normalize missing to `None` _data["dimension_names"] = _data.pop("dimension_names", None) @@ -264,23 +263,38 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self: @overload -def parse_fill_value(fill_value: object, dtype: BOOL_DTYPE) -> BOOL: ... +def parse_fill_value( + fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None, + dtype: BOOL_DTYPE, +) -> BOOL: ... @overload -def parse_fill_value(fill_value: object, dtype: INTEGER_DTYPE) -> INTEGER: ... +def parse_fill_value( + fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None, + dtype: INTEGER_DTYPE, +) -> INTEGER: ... @overload -def parse_fill_value(fill_value: object, dtype: FLOAT_DTYPE) -> FLOAT: ... +def parse_fill_value( + fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None, + dtype: FLOAT_DTYPE, +) -> FLOAT: ... @overload -def parse_fill_value(fill_value: object, dtype: COMPLEX_DTYPE) -> COMPLEX: ... +def parse_fill_value( + fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None, + dtype: COMPLEX_DTYPE, +) -> COMPLEX: ... @overload -def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any: +def parse_fill_value( + fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None, + dtype: np.dtype[Any], +) -> Any: # This dtype[Any] is unfortunately necessary right now. # See https://github.com/zarr-developers/zarr-python/issues/2131#issuecomment-2318010899 # for more details, but `dtype` here (which comes from `parse_dtype`) @@ -292,7 +306,7 @@ def parse_fill_value(fill_value: object, dtype: np.dtype[Any]) -> Any: def parse_fill_value( - fill_value: object, + fill_value: int | float | complex | str | bytes | np.generic | Sequence[Any] | bool | None, dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | np.dtype[Any], ) -> BOOL | INTEGER | FLOAT | COMPLEX | Any: """ @@ -326,11 +340,11 @@ def parse_fill_value( else: msg = ( f"Got an invalid fill value for complex data type {dtype}." - f"Expected a sequence with 2 elements, but {fill_value} has " + f"Expected a sequence with 2 elements, but {fill_value!r} has " f"length {len(fill_value)}." ) raise ValueError(msg) - msg = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype}." + msg = f"Cannot parse non-string sequence {fill_value!r} as a scalar with type {dtype}." raise TypeError(msg) # Cast the fill_value to the given dtype @@ -339,7 +353,7 @@ def parse_fill_value( except (ValueError, OverflowError, TypeError) as e: raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") from e # Check if the value is still representable by the dtype - if fill_value != casted_value: + if fill_value != casted_value and not (np.isnan(fill_value) and np.isnan(casted_value)): raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") return casted_value @@ -434,7 +448,7 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType: def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]: try: dtype = np.dtype(data) - except TypeError as e: + except (ValueError, TypeError) as e: raise ValueError(f"Invalid V3 data_type: {data}") from e # check that this is a valid v3 data_type try: diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index fe5c782a1b..b3362c52b0 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -122,8 +122,10 @@ def test_array_v3_fill_value_default( @pytest.mark.parametrize("store", ["memory"], indirect=True) -@pytest.mark.parametrize("fill_value", [False, 0.0, 1, 2.3]) -@pytest.mark.parametrize("dtype_str", ["bool", "uint8", "float32", "complex64"]) +@pytest.mark.parametrize( + "dtype_str,fill_value", + [("bool", True), ("uint8", 99), ("float32", -99.9), ("complex64", 3 + 4j)], +) def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str) -> None: shape = (10,) arr = Array.create( diff --git a/tests/v3/test_metadata/test_v3.py b/tests/v3/test_metadata/test_v3.py index 7bfa2e0192..4aecbba0b6 100644 --- a/tests/v3/test_metadata/test_v3.py +++ b/tests/v3/test_metadata/test_v3.py @@ -79,8 +79,19 @@ def test_parse_auto_fill_value(dtype_str: str) -> None: assert parse_fill_value(fill_value, dtype) == dtype.type(0) -@pytest.mark.parametrize("fill_value", [0, 1.11, False, True]) -@pytest.mark.parametrize("dtype_str", dtypes) +@pytest.mark.parametrize( + "fill_value,dtype_str", + [ + (True, "bool"), + (False, "bool"), + (-8, "int8"), + (0, "int16"), + (1e10, "uint64"), + (-999, "float32"), + (1e32, "float64"), + (0j, "complex64"), + ], +) def test_parse_fill_value_valid(fill_value: Any, dtype_str: str) -> None: """ Test that parse_fill_value(fill_value, dtype) casts fill_value to the given dtype. @@ -141,8 +152,7 @@ def test_parse_fill_value_invalid_type(fill_value: Any, dtype_str: str) -> None: This test excludes bool because the bool constructor takes anything. """ dtype = np.dtype(dtype_str) - match = "must be" - with pytest.raises(TypeError, match=match): + with pytest.raises(ValueError, match=r"fill value .* is not valid for dtype .*"): parse_fill_value(fill_value, dtype) @@ -269,13 +279,13 @@ async def test_invalid_dtype_raises() -> None: "codecs": (), "fill_value": np.datetime64(0, "ns"), } - with pytest.raises(ValueError, match=r"Invalid V3 data_type"): + with pytest.raises(ValueError, match=r".* is not a valid DataType"): ArrayV3Metadata.from_dict(metadata_dict) @pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()]) def test_parse_invalid_dtype_raises(data): - with pytest.raises(ValueError, match=r"Invalid V3 data_type"): + with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"): parse_dtype(data) @@ -293,5 +303,5 @@ async def test_invalid_fill_value_raises(data_type: str, fill_value: int | float "codecs": (), "fill_value": fill_value, # this is not a valid fill value for uint8 } - with pytest.raises(ValueError, match=rf"fill value .* is not valid for dtype {data_type}"): + with pytest.raises(ValueError, match=r"fill value .* is not valid for dtype .*"): ArrayV3Metadata.from_dict(metadata_dict) From 1ba45fc0d2b24cc1c73fac1a66aad62aa3107cbe Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Fri, 20 Sep 2024 19:54:21 -0700 Subject: [PATCH 5/7] fix for numpy==1 --- src/zarr/core/metadata/v3.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index da5f87cf08..09a9678971 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, cast, overload if TYPE_CHECKING: @@ -349,12 +350,25 @@ def parse_fill_value( # Cast the fill_value to the given dtype try: - casted_value = np.dtype(dtype).type(fill_value) + # This warning filter can be removed after Zarr supports numpy>=2.0 + # The warning is saying that the future behavior of out of bounds casting will be to raise + # an OverflowError. In the meantime, we allow overflow and catch cases where + # fill_value != casted_value below. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + casted_value = np.dtype(dtype).type(fill_value) except (ValueError, OverflowError, TypeError) as e: raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") from e # Check if the value is still representable by the dtype - if fill_value != casted_value and not (np.isnan(fill_value) and np.isnan(casted_value)): - raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") + if dtype.kind == "f": + # float comparison is not exact, especially when dtype Date: Sat, 21 Sep 2024 20:39:39 -0700 Subject: [PATCH 6/7] custom v3 json encoder --- src/zarr/core/metadata/v3.py | 104 +++++++++++++++++++----------- tests/v3/test_api.py | 4 +- tests/v3/test_metadata/test_v3.py | 39 ++++++++++- 3 files changed, 109 insertions(+), 38 deletions(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 09a9678971..eb4e1155e0 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -71,6 +71,67 @@ def parse_dimension_names(data: object) -> tuple[str | None, ...] | None: raise TypeError(msg) +class V3JsonEncoder(json.JSONEncoder): + def __init__(self, *args: Any, **kwargs: Any): + self.indent = kwargs.pop("indent", config.get("json_indent")) + super().__init__(*args, **kwargs) + + def default(self, o: object) -> Any: + if isinstance(o, np.dtype): + return str(o) + if np.isscalar(o): + out: Any + if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"): + # https://github.com/zarr-developers/zarr-python/issues/2119 + # `.item()` on a datetime type might or might not return an + # integer, depending on the value. + # Explicitly cast to an int first, and then grab .item() + out = o.view("i8").item() + else: + # convert numpy scalar to python type, and pass + # python types through + out = getattr(o, "item", lambda: o)() + if isinstance(out, complex): + # python complex types are not JSON serializable, so we use the + # serialization defined in the zarr v3 spec + return [out.real, out.imag] + elif np.isnan(out): + return "NaN" + elif np.isinf(out): + return "Infinity" if out > 0 else "-Infinity" + return out + elif isinstance(o, Enum): + return o.name + # this serializes numcodecs compressors + # todo: implement to_dict for codecs + elif isinstance(o, numcodecs.abc.Codec): + config: dict[str, Any] = o.get_config() + return config + else: + return super().default(o) + + +def _replace_special_floats(obj: object) -> Any: + """Helper function to replace NaN/Inf/-Inf values with special strings + + Note: this cannot be done in the V3JsonEncoder because Python's `json.dumps` optimistically + converts NaN/Inf values to special types outside of the encoding step. + """ + print(obj) + if isinstance(obj, float): + if np.isnan(obj): + return "NaN" + elif np.isinf(obj): + return "Infinity" if obj > 0 else "-Infinity" + elif isinstance(obj, dict): + # Recursively replace in dictionaries + return {k: _replace_special_floats(v) for k, v in obj.items()} + elif isinstance(obj, list): + # Recursively replace in lists + return [_replace_special_floats(item) for item in obj] + return obj + + @dataclass(frozen=True, kw_only=True) class ArrayV3Metadata(ArrayMetadata): shape: ChunkCoords @@ -170,41 +231,8 @@ def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str: return self.chunk_key_encoding.encode_chunk_key(chunk_coords) def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: - def _json_convert(o: object) -> Any: - if isinstance(o, np.dtype): - return str(o) - if np.isscalar(o): - out: Any - if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"): - # https://github.com/zarr-developers/zarr-python/issues/2119 - # `.item()` on a datetime type might or might not return an - # integer, depending on the value. - # Explicitly cast to an int first, and then grab .item() - out = o.view("i8").item() - else: - # convert numpy scalar to python type, and pass - # python types through - out = getattr(o, "item", lambda: o)() - if isinstance(out, complex): - # python complex types are not JSON serializable, so we use the - # serialization defined in the zarr v3 spec - return [out.real, out.imag] - return out - if isinstance(o, Enum): - return o.name - # this serializes numcodecs compressors - # todo: implement to_dict for codecs - elif isinstance(o, numcodecs.abc.Codec): - config: dict[str, Any] = o.get_config() - return config - raise TypeError - - json_indent = config.get("json_indent") - return { - ZARR_JSON: prototype.buffer.from_bytes( - json.dumps(self.to_dict(), default=_json_convert, indent=json_indent).encode() - ) - } + d = _replace_special_floats(self.to_dict()) + return {ZARR_JSON: prototype.buffer.from_bytes(json.dumps(d, cls=V3JsonEncoder).encode())} @classmethod def from_dict(cls, data: dict[str, JSON]) -> Self: @@ -360,7 +388,11 @@ def parse_fill_value( except (ValueError, OverflowError, TypeError) as e: raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") from e # Check if the value is still representable by the dtype - if dtype.kind == "f": + if fill_value == "NaN" and np.isnan(casted_value): + pass + elif fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value): + pass + elif dtype.kind == "f": # float comparison is not exact, especially when dtype None: # 'r' means read only (must exist) with pytest.raises(FileNotFoundError): zarr.open(store=tmp_path, mode="r") - zarr.ones(store=tmp_path, shape=(3, 3)) + z1 = zarr.ones(store=tmp_path, shape=(3, 3)) + assert z1.fill_value == 1 z2 = zarr.open(store=tmp_path, mode="r") assert isinstance(z2, Array) + assert z2.fill_value == 1 assert (z2[:] == 1).all() with pytest.raises(ValueError): z2[:] = 3 diff --git a/tests/v3/test_metadata/test_v3.py b/tests/v3/test_metadata/test_v3.py index 4aecbba0b6..d4cf0c73e3 100644 --- a/tests/v3/test_metadata/test_v3.py +++ b/tests/v3/test_metadata/test_v3.py @@ -1,9 +1,11 @@ from __future__ import annotations +import json import re from typing import TYPE_CHECKING, Literal from zarr.codecs.bytes import BytesCodec +from zarr.core.buffer import default_buffer_prototype from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding from zarr.core.metadata.v3 import ArrayV3Metadata @@ -89,6 +91,10 @@ def test_parse_auto_fill_value(dtype_str: str) -> None: (1e10, "uint64"), (-999, "float32"), (1e32, "float64"), + (float("NaN"), "float64"), + (np.nan, "float64"), + (np.inf, "float64"), + (-1 * np.inf, "float64"), (0j, "complex64"), ], ) @@ -97,7 +103,12 @@ def test_parse_fill_value_valid(fill_value: Any, dtype_str: str) -> None: Test that parse_fill_value(fill_value, dtype) casts fill_value to the given dtype. """ dtype = np.dtype(dtype_str) - assert parse_fill_value(fill_value, dtype) == dtype.type(fill_value) + parsed = parse_fill_value(fill_value, dtype) + + if np.isnan(fill_value): + assert np.isnan(parsed) + else: + assert parsed == dtype.type(fill_value) @pytest.mark.parametrize("fill_value", ["not a valid value"]) @@ -305,3 +316,29 @@ async def test_invalid_fill_value_raises(data_type: str, fill_value: int | float } with pytest.raises(ValueError, match=r"fill value .* is not valid for dtype .*"): ArrayV3Metadata.from_dict(metadata_dict) + + +@pytest.mark.parametrize("fill_value", [("NaN"), "Infinity", "-Infinity"]) +async def test_special_float_fill_values(fill_value: str) -> None: + metadata_dict = { + "zarr_format": 3, + "node_type": "array", + "shape": (1,), + "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, + "data_type": "float64", + "chunk_key_encoding": {"name": "default", "separator": "."}, + "codecs": (), + "fill_value": fill_value, # this is not a valid fill value for uint8 + } + m = ArrayV3Metadata.from_dict(metadata_dict) + d = json.loads(m.to_buffer_dict(default_buffer_prototype())["zarr.json"].to_bytes()) + assert m.fill_value is not None + if fill_value == "NaN": + assert np.isnan(m.fill_value) + assert d["fill_value"] == "NaN" + elif fill_value == "Infinity": + assert np.isposinf(m.fill_value) + assert d["fill_value"] == "Infinity" + elif fill_value == "-Infinity": + assert np.isneginf(m.fill_value) + assert d["fill_value"] == "-Infinity" From 9736b4b7ba06f1a9210775b785510384b9771a8b Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Mon, 23 Sep 2024 11:39:45 -0700 Subject: [PATCH 7/7] Update src/zarr/core/metadata/v3.py Co-authored-by: Davis Bennett --- src/zarr/core/metadata/v3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index eb4e1155e0..603cd343af 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -117,7 +117,6 @@ def _replace_special_floats(obj: object) -> Any: Note: this cannot be done in the V3JsonEncoder because Python's `json.dumps` optimistically converts NaN/Inf values to special types outside of the encoding step. """ - print(obj) if isinstance(obj, float): if np.isnan(obj): return "NaN"