Skip to content

Commit

Permalink
fix: validate v3 dtypes when loading/creating v3 metadata (#2209)
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamman authored Sep 23, 2024
1 parent fb28fa5 commit 2d3a36c
Show file tree
Hide file tree
Showing 9 changed files with 291 additions and 80 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ filterwarnings = [
"ignore:PY_SSIZE_T_CLEAN will be required.*:DeprecationWarning",
"ignore:The loop argument is deprecated since Python 3.8.*:DeprecationWarning",
"ignore:Creating a zarr.buffer.gpu.*:UserWarning",
"ignore:Duplicate name:UserWarning", # from ZipFile
]
markers = [
"gpu: mark a test as requiring CuPy and GPU"
Expand Down
5 changes: 2 additions & 3 deletions src/zarr/core/array_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal

from zarr.core.common import parse_dtype, parse_fill_value, parse_order, parse_shapelike
from zarr.core.common import parse_fill_value, parse_order, parse_shapelike

if TYPE_CHECKING:
import numpy as np
Expand All @@ -29,12 +29,11 @@ def __init__(
prototype: BufferPrototype,
) -> None:
shape_parsed = parse_shapelike(shape)
dtype_parsed = parse_dtype(dtype)
fill_value_parsed = parse_fill_value(fill_value)
order_parsed = parse_order(order)

object.__setattr__(self, "shape", shape_parsed)
object.__setattr__(self, "dtype", dtype_parsed)
object.__setattr__(self, "dtype", dtype)
object.__setattr__(self, "fill_value", fill_value_parsed)
object.__setattr__(self, "order", order_parsed)
object.__setattr__(self, "prototype", prototype)
Expand Down
7 changes: 0 additions & 7 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Iterator

import numpy as np
import numpy.typing as npt

ZARR_JSON = "zarr.json"
ZARRAY_JSON = ".zarray"
Expand Down Expand Up @@ -155,11 +153,6 @@ def parse_shapelike(data: int | Iterable[int]) -> tuple[int, ...]:
return data_tuple


def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
    """Coerce a dtype-like value into a NumPy dtype.

    The input is handed straight to ``np.dtype``; nothing beyond what that
    constructor itself checks is validated here.
    """
    # todo: real validation
    dtype_obj = np.dtype(data)
    return dtype_obj


def parse_fill_value(data: Any) -> Any:
    """Return the given fill value as-is.

    This is currently a pass-through: the argument is handed back
    without being inspected or converted.
    """
    # todo: real validation
    fill_value = data
    return fill_value
Expand Down
29 changes: 25 additions & 4 deletions src/zarr/core/metadata/v2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections.abc import Iterable
from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand All @@ -21,7 +22,7 @@
from zarr.core.array_spec import ArraySpec
from zarr.core.chunk_grids import RegularChunkGrid
from zarr.core.chunk_key_encodings import parse_separator
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_dtype, parse_shapelike
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_shapelike
from zarr.core.config import config, parse_indexing_order
from zarr.core.metadata.common import ArrayMetadata, parse_attributes

Expand Down Expand Up @@ -100,9 +101,24 @@ def _json_convert(
else:
return o.descr
if np.isscalar(o):
# convert numpy scalar to python type, and pass
# python types through
return getattr(o, "item", lambda: o)()
out: Any
if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"):
# https://github.com/zarr-developers/zarr-python/issues/2119
# `.item()` on a datetime type might or might not return an
# integer, depending on the value.
# Explicitly cast to an int first, and then grab .item()
out = o.view("i8").item()
else:
# convert numpy scalar to python type, and pass
# python types through
out = getattr(o, "item", lambda: o)()
if isinstance(out, complex):
# python complex types are not JSON serializable, so we use the
# serialization defined in the zarr v3 spec
return [out.real, out.imag]
return out
if isinstance(o, Enum):
return o.name
raise TypeError

zarray_dict = self.to_dict()
Expand Down Expand Up @@ -157,6 +173,11 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
return replace(self, attributes=attributes)


def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
    """Build a NumPy dtype from a dtype-like input.

    Conversion is delegated entirely to ``np.dtype``; no additional
    validation is performed in this function.
    """
    # todo: real validation
    parsed = np.dtype(data)
    return parsed


def parse_zarr_format(data: object) -> Literal[2]:
if data == 2:
return 2
Expand Down
Loading

0 comments on commit 2d3a36c

Please sign in to comment.