Add string and bytes dtypes plus vlen-utf8 and vlen-bytes codecs #2036

Merged
Merged 35 commits on Oct 8, 2024

Changes from 15 commits

Commits (35)
c05b9d1
add legacy vlen-utf8 codec
rabernat Jul 14, 2024
c86ddc6
Merge branch 'v3' into ryan/legacy-vlen
rabernat Sep 29, 2024
a322124
got it working again
rabernat Sep 29, 2024
2a1e2e3
got strings working; broke everything else
rabernat Oct 1, 2024
1d3d7a5
change v3.metadata.data_type type
rabernat Oct 1, 2024
cd40b08
merged
rabernat Oct 1, 2024
988f9df
fixed tests
rabernat Oct 1, 2024
507161a
satisfy mypy for tests
rabernat Oct 1, 2024
1ae5e63
make strings work
rabernat Oct 3, 2024
94ecdb5
add missing module
rabernat Oct 3, 2024
2c7d638
Merge branch 'v3' into ryan/legacy-vlen
d-v-b Oct 3, 2024
b1717d8
Merge remote-tracking branch 'upstream/v3' into ryan/legacy-vlen
rabernat Oct 4, 2024
79b7d43
store -> storage
rabernat Oct 4, 2024
a5c2a37
rename module
rabernat Oct 4, 2024
717f0c7
Merge remote-tracking branch 'origin/ryan/legacy-vlen' into ryan/lega…
rabernat Oct 4, 2024
b90d8f3
merged
rabernat Oct 4, 2024
0406ea1
add vlen bytes
rabernat Oct 7, 2024
8e61a18
fix type assertions in test
rabernat Oct 7, 2024
6cf7dde
much better validation of fill value
rabernat Oct 7, 2024
28d58fa
retype parse_fill_value
rabernat Oct 7, 2024
c6de878
tests pass but not mypy
rabernat Oct 7, 2024
4f026db
attempted to change parse_fill_value typing
rabernat Oct 8, 2024
e427c7a
restore DEFAULT_DTYPE
rabernat Oct 8, 2024
7d9d897
fixup
TomAugspurger Oct 8, 2024
0c21994
docstring
TomAugspurger Oct 8, 2024
c12ac41
update test
TomAugspurger Oct 8, 2024
3aeea1e
add better DataType tests
rabernat Oct 8, 2024
cae7055
more progress on typing; still not passing mypy
rabernat Oct 8, 2024
1aeb49a
fix typing yay!
rabernat Oct 8, 2024
6714bad
make types work with numpy < 2
rabernat Oct 8, 2024
2edf3b8
Apply suggestions from code review
rabernat Oct 8, 2024
12a0d65
Apply suggestions from code review
rabernat Oct 8, 2024
7ba7077
apply Joe's suggestions
rabernat Oct 8, 2024
1e828b4
add missing module
rabernat Oct 8, 2024
ba0f093
make _STRING_DTYPE private to try to make sphinx happy
rabernat Oct 8, 2024
2 changes: 2 additions & 0 deletions src/zarr/codecs/__init__.py
@@ -7,6 +7,7 @@
from zarr.codecs.pipeline import BatchedCodecPipeline
from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation
from zarr.codecs.transpose import TransposeCodec
from zarr.codecs.vlen_utf8 import VLenUTF8Codec
from zarr.codecs.zstd import ZstdCodec

__all__ = [
@@ -21,5 +22,6 @@
"ShardingCodec",
"ShardingCodecIndexLocation",
"TransposeCodec",
"VLenUTF8Codec",
"ZstdCodec",
]
71 changes: 71 additions & 0 deletions src/zarr/codecs/vlen_utf8.py
@@ -0,0 +1,71 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import numpy as np
from numcodecs.vlen import VLenUTF8

from zarr.abc.codec import ArrayBytesCodec
from zarr.core.buffer import Buffer, NDBuffer
from zarr.core.common import JSON, parse_named_configuration
from zarr.registry import register_codec
from zarr.strings import cast_to_string_dtype

if TYPE_CHECKING:
from typing import Self

from zarr.core.array_spec import ArraySpec


# can use a global because there are no parameters
vlen_utf8_codec = VLenUTF8()


@dataclass(frozen=True)
class VLenUTF8Codec(ArrayBytesCodec):
@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
_, configuration_parsed = parse_named_configuration(
data, "vlen-utf8", require_configuration=False
)
configuration_parsed = configuration_parsed or {}
return cls(**configuration_parsed)

def to_dict(self) -> dict[str, JSON]:
return {"name": "vlen-utf8", "configuration": {}}
Review comment from rabernat (Contributor, Author):
One thought I had while implementing this: often the original numpy array is a fixed-length type (e.g. <U5). In V2, this dtype could be stored directly in the metadata, whereas now we are losing the 5, which is potentially useful information.

In the future, we may want to try to resurface this information.
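
A minimal sketch (not part of the diff) of how the declared width gets lost once the data round-trips through an object / variable-length representation:

import numpy as np

fixed = np.array(["ab", "cd"], dtype="<U5")     # declared width is 5, actual strings are shorter
decoded = fixed.astype(object).astype(np.str_)  # roughly what a vlen round trip does
print(fixed.dtype)                              # <U5
print(decoded.dtype)                            # <U2 -- the "5" from the original dtype is gone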


def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
return self

async def _decode_single(
self,
chunk_bytes: Buffer,
chunk_spec: ArraySpec,
) -> NDBuffer:
assert isinstance(chunk_bytes, Buffer)

raw_bytes = chunk_bytes.as_array_like()
decoded = vlen_utf8_codec.decode(raw_bytes)
assert decoded.dtype == np.object_
decoded.shape = chunk_spec.shape
# coming out of the codec, we know this is safe, so don't issue a warning
as_string_dtype = cast_to_string_dtype(decoded, safe=True)
return chunk_spec.prototype.nd_buffer.from_numpy_array(as_string_dtype)

async def _encode_single(
self,
chunk_array: NDBuffer,
chunk_spec: ArraySpec,
) -> Buffer | None:
assert isinstance(chunk_array, NDBuffer)
return chunk_spec.prototype.buffer.from_bytes(
vlen_utf8_codec.encode(chunk_array.as_numpy_array())
)

def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
# what is input_byte_length for an object dtype?
raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs")


register_codec("vlen-utf8", VLenUTF8Codec)
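
For context, a rough usage sketch (not part of the diff) of the numcodecs codec that VLenUTF8Codec wraps; it assumes numcodecs is installed and the sample data is made up:

import numpy as np
from numcodecs.vlen import VLenUTF8

codec = VLenUTF8()
original = np.array(["héllo", "wörld", ""], dtype=object)  # object array of Python strings
encoded = codec.encode(original)                           # UTF-8 bytes with per-element lengths
decoded = codec.decode(encoded)                            # back to a 1-D object array of strings
assert (decoded == original).all()

_decode_single above performs the same round trip per chunk, then reshapes the result to the chunk shape and converts the object array to a string dtype via cast_to_string_dtype.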
8 changes: 5 additions & 3 deletions src/zarr/core/buffer/core.py
@@ -313,8 +313,7 @@ class NDBuffer:
"""

def __init__(self, array: NDArrayLike) -> None:
# assert array.ndim > 0
assert array.dtype != object
# assert array.dtype != object
self._data = array

@classmethod
@@ -467,9 +466,12 @@ def all_equal(self, other: Any, equal_nan: bool = True) -> bool:
# Handle None fill_value for Zarr V2
return False
# use array_equal to obtain equal_nan=True functionality
# Note from Ryan: doesn't this lead to a huge amount of unnecessary memory allocation on every single chunk?
Reply from a reviewer (Contributor):
I think we're OK.

In [10]: a, b = np.broadcast_arrays(np.ones(10), 2)

In [11]: b.base
Out[11]: array(2)

I think that it just does some tricks with the strides or something?
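
A quick check (not from the PR) of that claim: broadcasting a scalar against an array yields a zero-strided view rather than a full-size copy.

import numpy as np

arr = np.ones(10)
arr_b, fill_b = np.broadcast_arrays(arr, 2)
print(fill_b.shape, fill_b.strides)  # (10,) (0,) -- ten logical elements backed by one value
print(fill_b.base)                   # array(2)

So np.array_equal over the broadcast pair reads the same fill value repeatedly rather than materialising a chunk-sized copy of it.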

# Since fill-value is a scalar, isn't there a faster path than allocating a new array for fill value
# every single time we have to write data?
_data, other = np.broadcast_arrays(self._data, other)
return np.array_equal(
self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "US" else False
self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "USTO" else False
)

def fill(self, value: Any) -> None:
1 change: 1 addition & 0 deletions src/zarr/core/config.py
@@ -58,6 +58,7 @@ def reset(self) -> None:
"crc32c": "zarr.codecs.crc32c_.Crc32cCodec",
"sharding_indexed": "zarr.codecs.sharding.ShardingCodec",
"transpose": "zarr.codecs.transpose.TransposeCodec",
"vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec",
},
"buffer": "zarr.core.buffer.cpu.Buffer",
"ndbuffer": "zarr.core.buffer.cpu.NDBuffer",
70 changes: 45 additions & 25 deletions src/zarr/core/metadata/v3.py
@@ -6,8 +6,6 @@
if TYPE_CHECKING:
from typing import Self

import numpy.typing as npt

from zarr.core.buffer import Buffer, BufferPrototype
from zarr.core.chunk_grids import ChunkGrid
from zarr.core.common import JSON, ChunkCoords
@@ -20,6 +18,7 @@

import numcodecs.abc
import numpy as np
import numpy.typing as npt

from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
from zarr.core.array_spec import ArraySpec
@@ -30,6 +29,7 @@
from zarr.core.config import config
from zarr.core.metadata.common import ArrayMetadata, parse_attributes
from zarr.registry import get_codec_class
from zarr.strings import STRING_DTYPE


def parse_zarr_format(data: object) -> Literal[3]:
@@ -152,7 +152,7 @@ def _replace_special_floats(obj: object) -> Any:
@dataclass(frozen=True, kw_only=True)
class ArrayV3Metadata(ArrayMetadata):
shape: ChunkCoords
data_type: np.dtype[Any]
data_type: DataType
chunk_grid: ChunkGrid
chunk_key_encoding: ChunkKeyEncoding
fill_value: Any
@@ -167,7 +167,7 @@ def __init__(
self,
*,
shape: Iterable[int],
data_type: npt.DTypeLike,
data_type: npt.DTypeLike | DataType,
chunk_grid: dict[str, JSON] | ChunkGrid,
chunk_key_encoding: dict[str, JSON] | ChunkKeyEncoding,
fill_value: Any,
@@ -180,18 +180,18 @@ def __init__(
Because the class is a frozen dataclass, we set attributes using object.__setattr__
"""
shape_parsed = parse_shapelike(shape)
data_type_parsed = parse_dtype(data_type)
data_type_parsed = DataType.parse(data_type)
chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid)
chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding)
dimension_names_parsed = parse_dimension_names(dimension_names)
fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed)
fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed.to_numpy_dtype())
attributes_parsed = parse_attributes(attributes)
codecs_parsed_partial = parse_codecs(codecs)
storage_transformers_parsed = parse_storage_transformers(storage_transformers)

array_spec = ArraySpec(
shape=shape_parsed,
dtype=data_type_parsed,
dtype=data_type_parsed.to_numpy_dtype(),
fill_value=fill_value_parsed,
order="C", # TODO: order is not needed here.
prototype=default_buffer_prototype(), # TODO: prototype is not needed here.
@@ -224,11 +224,14 @@ def _validate_metadata(self) -> None:
if self.fill_value is None:
raise ValueError("`fill_value` is required.")
for codec in self.codecs:
codec.validate(shape=self.shape, dtype=self.data_type, chunk_grid=self.chunk_grid)
codec.validate(
shape=self.shape, dtype=self.data_type.to_numpy_dtype(), chunk_grid=self.chunk_grid
)

@property
def dtype(self) -> np.dtype[Any]:
return self.data_type
"""Interpret Zarr dtype as NumPy dtype"""
return self.data_type.to_numpy_dtype()

@property
def ndim(self) -> int:
@@ -266,13 +269,13 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
_ = parse_node_type_array(_data.pop("node_type"))

# check that the data_type attribute is valid
_ = DataType(_data["data_type"])
data_type = DataType.parse(_data.pop("data_type"))

# dimension_names key is optional, normalize missing to `None`
_data["dimension_names"] = _data.pop("dimension_names", None)
# attributes key is optional, normalize missing to `None`
_data["attributes"] = _data.pop("attributes", None)
return cls(**_data) # type: ignore[arg-type]
return cls(**_data, data_type=data_type) # type: ignore[arg-type]

def to_dict(self) -> dict[str, JSON]:
out_dict = super().to_dict()
@@ -445,6 +448,7 @@ class DataType(Enum):
float64 = "float64"
complex64 = "complex64"
complex128 = "complex128"
string = "string"

@property
def byte_count(self) -> int:
@@ -490,8 +494,16 @@ def to_numpy_shortname(self) -> str:
}
return data_type_to_numpy[self]

def to_numpy_dtype(self) -> np.dtype[Any]:
if self == DataType.string:
return STRING_DTYPE
else:
return np.dtype(self.to_numpy_shortname())

@classmethod
def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
def from_numpy_dtype(cls, dtype: np.dtype[Any]) -> DataType:
if np.issubdtype(np.str_, dtype):
return DataType.string
dtype_to_data_type = {
"|b1": "bool",
"bool": "bool",
@@ -511,16 +523,24 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
}
return DataType[dtype_to_data_type[dtype.str]]


def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
try:
dtype = np.dtype(data)
except (ValueError, TypeError) as e:
raise ValueError(f"Invalid V3 data_type: {data}") from e
# check that this is a valid v3 data_type
try:
_ = DataType.from_dtype(dtype)
except KeyError as e:
raise ValueError(f"Invalid V3 data_type: {dtype}") from e

return dtype
@classmethod
def parse(cls, dtype: None | DataType | Any) -> DataType:
if dtype is None:
# the default dtype
return DataType.float64
if isinstance(dtype, DataType):
return dtype
try:
return DataType(dtype)
except ValueError:
pass
try:
dtype = np.dtype(dtype)
except (ValueError, TypeError) as e:
raise ValueError(f"Invalid V3 data_type: {dtype}") from e
# check that this is a valid v3 data_type
try:
data_type = DataType.from_numpy_dtype(dtype)
except KeyError as e:
raise ValueError(f"Invalid V3 data_type: {dtype}") from e
return data_type
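
A sketch (not part of the diff) of the resolution order implemented by the new DataType.parse classmethod, assuming zarr.core.metadata.v3 from this branch is importable:

import numpy as np
from zarr.core.metadata.v3 import DataType

assert DataType.parse(None) is DataType.float64                 # missing dtype falls back to the default
assert DataType.parse(DataType.string) is DataType.string       # DataType members pass straight through
assert DataType.parse("uint8") is DataType.uint8                # Zarr v3 names resolve via the enum
assert DataType.parse(np.dtype("float32")) is DataType.float32  # NumPy dtypes go through from_numpy_dtype

Anything that is neither a recognised v3 name nor a valid NumPy dtype raises ValueError("Invalid V3 data_type: ...").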
36 changes: 36 additions & 0 deletions src/zarr/strings.py
@@ -0,0 +1,36 @@
from typing import Any
from warnings import warn

import numpy as np

try:
STRING_DTYPE = np.dtype("T")
NUMPY_SUPPORTS_VLEN_STRING = True
except TypeError:
STRING_DTYPE = np.dtype("object")
NUMPY_SUPPORTS_VLEN_STRING = False


def cast_to_string_dtype(
data: np.ndarray[Any, np.dtype[Any]], safe: bool = False
) -> np.ndarray[Any, np.dtype[Any]]:
if np.issubdtype(data.dtype, np.str_):
return data
if np.issubdtype(data.dtype, np.object_):
if NUMPY_SUPPORTS_VLEN_STRING:
try:
# cast to variable-length string dtype, fail if object contains non-string data
# mypy says "error: Unexpected keyword argument "coerce" for "StringDType" [call-arg]"
return data.astype(np.dtypes.StringDType(coerce=False), copy=False) # type: ignore[call-arg]
except ValueError as e:
raise ValueError("Cannot cast object dtype to string dtype") from e
else:
out = data.astype(np.str_)
if not safe:
warn(
f"Casted object dtype to string dtype {out.dtype}. To avoid this warning, "
"cast the data to a string dtype before passing to Zarr or upgrade to NumPy >= 2.",
stacklevel=2,
)
return out
raise ValueError(f"Cannot cast dtype {data.dtype} to string dtype")
51 changes: 51 additions & 0 deletions tests/v3/test_codecs/test_vlen.py
@@ -0,0 +1,51 @@
from typing import Any

import numpy as np
import pytest

from zarr import Array
from zarr.abc.store import Store
from zarr.codecs import VLenUTF8Codec
from zarr.core.metadata.v3 import ArrayV3Metadata, DataType
from zarr.storage.common import StorePath
from zarr.strings import NUMPY_SUPPORTS_VLEN_STRING

numpy_str_dtypes: list[type | None] = [None, str, np.dtypes.StrDType]
expected_zarr_string_dtype: np.dtype[Any]
if NUMPY_SUPPORTS_VLEN_STRING:
numpy_str_dtypes.append(np.dtypes.StringDType)
expected_zarr_string_dtype = np.dtypes.StringDType()
else:
expected_zarr_string_dtype = np.dtype("O")


@pytest.mark.parametrize("store", ["memory", "local"], indirect=["store"])
@pytest.mark.parametrize("dtype", numpy_str_dtypes)
async def test_vlen_string(store: Store, dtype: None | np.dtype[Any]) -> None:
strings = ["hello", "world", "this", "is", "a", "test"]
data = np.array(strings).reshape((2, 3))
if dtype is not None:
data = data.astype(dtype)

sp = StorePath(store, path="string")
a = Array.create(
sp,
shape=data.shape,
chunk_shape=data.shape,
dtype=data.dtype,
fill_value="",
codecs=[VLenUTF8Codec()],
)
assert isinstance(a.metadata, ArrayV3Metadata) # needed for mypy

a[:, :] = data
assert np.array_equal(data, a[:, :])
assert a.metadata.data_type == DataType.string
assert a.dtype == expected_zarr_string_dtype

# test round trip
b = Array.open(sp)
assert isinstance(b.metadata, ArrayV3Metadata) # needed for mypy
assert np.array_equal(data, b[:, :])
assert b.metadata.data_type == DataType.string
assert a.dtype == expected_zarr_string_dtype
1 change: 1 addition & 0 deletions tests/v3/test_config.py
@@ -58,6 +58,7 @@ def test_config_defaults_set() -> None:
"crc32c": "zarr.codecs.crc32c_.Crc32cCodec",
"sharding_indexed": "zarr.codecs.sharding.ShardingCodec",
"transpose": "zarr.codecs.transpose.TransposeCodec",
"vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec",
},
}
]