Skip to content

Commit

Permalink
ensure tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
d-v-b committed Dec 20, 2024
1 parent 68465db commit d7bb121
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
10 changes: 7 additions & 3 deletions src/zarr/api/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ async def create_array(
shape: ChunkCoords,
dtype: npt.DTypeLike,
chunk_shape: ChunkCoords,
shard_shape: ChunkCoords | None,
shard_shape: ChunkCoords | None = None,
filters: Iterable[dict[str, JSON] | Codec] = (),
compressors: Iterable[dict[str, JSON] | Codec] = (),
fill_value: Any | None = 0,
Expand Down Expand Up @@ -1016,14 +1016,18 @@ async def create_array(
_dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format)
config_parsed = parse_array_config(config)
if zarr_format == 2:
if shard_shape is not None or shard_shape != "auto":
if shard_shape is not None:
msg = (
'Zarr v2 arrays can only be created with `shard_shape` set to `None` or `"auto"`.'
f"Got `shard_shape={shard_shape}` instead."
)

raise ValueError(msg)
compressor, *rest = compressors
if len(tuple(compressors)) > 1:
compressor, *rest = compressors
else:
compressor = None
rest = ()
filters = (*filters, *rest)
if dimension_names is not None:
raise ValueError("Zarr v2 arrays do not support dimension names.")
Expand Down
25 changes: 20 additions & 5 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,14 @@ def test_read(store: Store) -> None:
"""
# create an array and a group
_ = create_group(store=store, path="group", attributes={"node_type": "group"})
_ = create_array(store=store, path="array", shape=(10, 10), attributes={"node_type": "array"})
_ = create_array(
store=store,
path="array",
shape=(10, 10),
chunk_shape=(1, 1),
dtype="uint8",
attributes={"node_type": "array"},
)

group_r = read(store, path="group")
assert isinstance(group_r, Group)
Expand All @@ -89,7 +96,9 @@ def test_create_array(store: Store) -> None:
shape = (10, 10)
path = "foo"
data_val = 1
array_w = create_array(store, path=path, shape=shape, attributes=attrs)
array_w = create_array(
store, path=path, shape=shape, attributes=attrs, chunk_shape=shape, dtype="uint8"
)
array_w[:] = data_val
assert array_w.shape == shape
assert array_w.attrs == attrs
Expand All @@ -107,7 +116,13 @@ def test_read_array(store: Store) -> None:
for zarr_format in (2, 3):
attrs = {"zarr_format": zarr_format}
node_w = create_array(
store, path=path, shape=shape, attributes=attrs, zarr_format=zarr_format
store,
path=path,
shape=shape,
attributes=attrs,
zarr_format=zarr_format,
chunk_shape=shape,
dtype="uint8",
)
node_w[:] = data_val

Expand Down Expand Up @@ -1214,9 +1229,9 @@ async def test_create_array_v2(store: MemoryStore) -> None:
store=store,
dtype=dtype,
shape=(10,),
shard_shape=(4,),
shard_shape=None,
chunk_shape=(4,),
zarr_format=3,
zarr_format=2,
filters=(Delta(dtype=dtype),),
compressors=(Zstd(level=3),),
)

0 comments on commit d7bb121

Please sign in to comment.