fix: validate v3 dtypes when loading/creating v3 metadata #2209

Merged 10 commits on Sep 23, 2024
4 changes: 2 additions & 2 deletions src/zarr/core/array_spec.py
@@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal

from zarr.core.common import parse_dtype, parse_fill_value, parse_order, parse_shapelike
from zarr.core.common import parse_fill_value, parse_order, parse_shapelike

if TYPE_CHECKING:
import numpy as np
@@ -29,7 +29,7 @@ def __init__(
prototype: BufferPrototype,
) -> None:
shape_parsed = parse_shapelike(shape)
dtype_parsed = parse_dtype(dtype)
dtype_parsed = dtype # parsing is likely not needed here
Member Author:
Suggested change
dtype_parsed = dtype # parsing is likely not needed here
dtype_parsed = dtype

The input type is already np.dtype[Any], so I don't think we need to parse this.

fill_value_parsed = parse_fill_value(fill_value)
order_parsed = parse_order(order)

7 changes: 0 additions & 7 deletions src/zarr/core/common.py
@@ -19,8 +19,6 @@
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Iterator

import numpy as np
import numpy.typing as npt

ZARR_JSON = "zarr.json"
ZARRAY_JSON = ".zarray"
@@ -154,11 +152,6 @@ def parse_shapelike(data: int | Iterable[int]) -> tuple[int, ...]:
return data_tuple


def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
# todo: real validation
return np.dtype(data)


def parse_fill_value(data: Any) -> Any:
# todo: real validation
return data
7 changes: 6 additions & 1 deletion src/zarr/core/metadata/v2.py
@@ -21,7 +21,7 @@
from zarr.core.array_spec import ArraySpec
from zarr.core.chunk_grids import RegularChunkGrid
from zarr.core.chunk_key_encodings import parse_separator
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_dtype, parse_shapelike
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_shapelike
from zarr.core.config import config, parse_indexing_order
from zarr.core.metadata.common import ArrayMetadata, parse_attributes

@@ -157,6 +157,11 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
return replace(self, attributes=attributes)


def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
# todo: real validation
return np.dtype(data)
Member Author:
We now parse v2 and v3 dtypes differently.
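
To make the difference concrete, here is a minimal sketch (assuming the module paths used in this diff, zarr.core.metadata.v2 and zarr.core.metadata.v3): v2 parsing defers entirely to numpy, while v3 parsing additionally requires the dtype to have a spelling in the v3 DataType enum.

import numpy as np

from zarr.core.metadata.v2 import parse_dtype as parse_dtype_v2
from zarr.core.metadata.v3 import parse_dtype as parse_dtype_v3

# v2: anything numpy accepts passes through unchanged.
assert parse_dtype_v2("datetime64[ns]") == np.dtype("datetime64[ns]")

# v3: dtypes without a v3 spelling are rejected.
try:
    parse_dtype_v3("datetime64[ns]")
except ValueError as err:
    print(err)  # Invalid V3 data_type: datetime64[ns]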



def parse_zarr_format(data: object) -> Literal[2]:
if data == 2:
return 2
32 changes: 31 additions & 1 deletion src/zarr/core/metadata/v3.py
@@ -24,7 +24,7 @@
from zarr.core.buffer import default_buffer_prototype
from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid
from zarr.core.chunk_key_encodings import ChunkKeyEncoding
from zarr.core.common import ZARR_JSON, parse_dtype, parse_named_configuration, parse_shapelike
from zarr.core.common import ZARR_JSON, parse_named_configuration, parse_shapelike
from zarr.core.config import config
from zarr.core.metadata.common import ArrayMetadata, parse_attributes
from zarr.registry import get_codec_class
@@ -215,6 +215,10 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
# check that the node_type attribute is correct
_ = parse_node_type_array(_data.pop("node_type"))

# check that the data_type attribute is valid
if _data["data_type"] not in DataType:
raise ValueError(f"Invalid V3 data_type: {_data['data_type']}")

# dimension_names key is optional, normalize missing to `None`
_data["dimension_names"] = _data.pop("dimension_names", None)
# attributes key is optional, normalize missing to `None`
@@ -345,8 +349,11 @@ class DataType(Enum):
uint16 = "uint16"
uint32 = "uint32"
uint64 = "uint64"
float16 = "float16"
float32 = "float32"
float64 = "float64"
complex64 = "complex64"
complex128 = "complex128"

@property
def byte_count(self) -> int:
@@ -360,8 +367,11 @@ def byte_count(self) -> int:
DataType.uint16: 2,
DataType.uint32: 4,
DataType.uint64: 8,
DataType.float16: 2,
DataType.float32: 4,
DataType.float64: 8,
DataType.complex64: 8,
DataType.complex128: 16,
}
return data_type_byte_counts[self]

@@ -381,8 +391,11 @@ def to_numpy_shortname(self) -> str:
DataType.uint16: "u2",
DataType.uint32: "u4",
DataType.uint64: "u8",
DataType.float16: "f2",
DataType.float32: "f4",
DataType.float64: "f8",
DataType.complex64: "c8",
DataType.complex128: "c16",
}
return data_type_to_numpy[self]

@@ -399,7 +412,24 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
"<u2": "uint16",
"<u4": "uint32",
"<u8": "uint64",
"<f2": "float16",
"<f4": "float32",
"<f8": "float64",
"<c8": "complex64",
"<c16": "complex128",
}
return DataType[dtype_to_data_type[dtype.str]]


def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
try:
dtype = np.dtype(data)
except TypeError as e:
raise ValueError(f"Invalid V3 data_type: {data}") from e
# check that this is a valid v3 data_type
try:
_ = DataType.from_dtype(dtype)
except KeyError as e:
raise ValueError(f"Invalid V3 data_type: {dtype}") from e

return dtype
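
As a quick consistency check on the new enum entries (a sketch, assuming a little-endian platform so that numpy reports dtype.str values such as "<f2"), the three lookup tables above agree with each other:

import numpy as np

from zarr.core.metadata.v3 import DataType

for name in ("float16", "complex64", "complex128"):
    data_type = DataType[name]
    np_dtype = np.dtype(data_type.to_numpy_shortname())
    # round-trips through from_dtype and agrees on the element size
    assert DataType.from_dtype(np_dtype) is data_type
    assert np_dtype.itemsize == data_type.byte_count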
4 changes: 3 additions & 1 deletion src/zarr/testing/strategies.py
@@ -35,7 +35,9 @@
paths = st.lists(node_names, min_size=1).map(lambda x: "/".join(x)) | st.just("/")
np_arrays = npst.arrays(
# TODO: re-enable timedeltas once they are supported
dtype=npst.scalar_dtypes().filter(lambda x: x.kind != "m"),
dtype=npst.scalar_dtypes().filter(
lambda x: (x.kind not in ["m", "M"]) and (x.byteorder not in [">"])
),
Member Author:
It is not clear to me how endianness works in v3. As far as I can tell, we are not handling this in a meaningful way today.

cc @dcherian and @d-v-b

Contributor:
I concur. It's not being handled today.
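
For reference, a small sketch of why big-endian dtypes are filtered out of the strategy (assuming the DataType.from_dtype mapping shown earlier in this diff, which keys on little-endian dtype.str values such as "<f4"):

import numpy as np

from zarr.core.metadata.v3 import DataType

DataType.from_dtype(np.dtype("<f4"))  # DataType.float32

try:
    # dtype.str is ">f4", which has no entry in the mapping
    DataType.from_dtype(np.dtype(">f4"))
except KeyError:
    print("big-endian dtypes do not map onto a v3 DataType yet")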

shape=npst.array_shapes(max_dims=4),
)
stores = st.builds(MemoryStore, st.just({}), mode=st.just("w"))
50 changes: 37 additions & 13 deletions tests/v3/test_metadata/test_v3.py
@@ -1,11 +1,9 @@
from __future__ import annotations

import json
import re
from typing import TYPE_CHECKING, Literal

from zarr.codecs.bytes import BytesCodec
from zarr.core.buffer import default_buffer_prototype
from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding
from zarr.core.metadata.v3 import ArrayV3Metadata

@@ -19,7 +17,12 @@
import numpy as np
import pytest

from zarr.core.metadata.v3 import parse_dimension_names, parse_fill_value, parse_zarr_format
from zarr.core.metadata.v3 import (
parse_dimension_names,
parse_dtype,
parse_fill_value,
parse_zarr_format,
)

bool_dtypes = ("bool",)

@@ -234,22 +237,43 @@ def test_metadata_to_dict(
assert observed == expected


@pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897])
@pytest.mark.parametrize("precision", ["ns", "D"])
async def test_datetime_metadata(fill_value: int, precision: str) -> None:
# @pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897])
# @pytest.mark.parametrize("precision", ["ns", "D"])
# async def test_datetime_metadata(fill_value: int, precision: str) -> None:
Member Author:
This test was creating an invalid dtype (datetime). We can leave it in place and bring it back when there is support for this dtype in v3.

# metadata_dict = {
# "zarr_format": 3,
# "node_type": "array",
# "shape": (1,),
# "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
# "data_type": f"<M8[{precision}]",
# "chunk_key_encoding": {"name": "default", "separator": "."},
# "codecs": (),
# "fill_value": np.datetime64(fill_value, precision),
# }
# metadata = ArrayV3Metadata.from_dict(metadata_dict)
# # ensure there isn't a TypeError here.
# d = metadata.to_buffer_dict(default_buffer_prototype())

# result = json.loads(d["zarr.json"].to_bytes())
# assert result["fill_value"] == fill_value


async def test_invalid_dtype_raises() -> None:
metadata_dict = {
"zarr_format": 3,
"node_type": "array",
"shape": (1,),
"chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
"data_type": f"<M8[{precision}]",
"data_type": "<M8[ns]",
"chunk_key_encoding": {"name": "default", "separator": "."},
"codecs": (),
"fill_value": np.datetime64(fill_value, precision),
"fill_value": np.datetime64(0, "ns"),
}
metadata = ArrayV3Metadata.from_dict(metadata_dict)
# ensure there isn't a TypeError here.
d = metadata.to_buffer_dict(default_buffer_prototype())
with pytest.raises(ValueError, match=r"Invalid V3 data_type"):
ArrayV3Metadata.from_dict(metadata_dict)


result = json.loads(d["zarr.json"].to_bytes())
assert result["fill_value"] == fill_value
@pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()])
def test_parse_invalid_dtype_raises(data):
with pytest.raises(ValueError, match=r"Invalid V3 data_type"):
parse_dtype(data)