Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always skip reads when completely overwriting chunks #2784

Merged
19 commits merged on Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/zarr/abc/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ async def encode(
@abstractmethod
async def read(
self,
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
out: NDBuffer,
drop_axes: tuple[int, ...] = (),
) -> None:
Expand All @@ -379,7 +379,7 @@ async def read(
@abstractmethod
async def write(
self,
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
value: NDBuffer,
drop_axes: tuple[int, ...] = (),
) -> None:
Expand Down
14 changes: 9 additions & 5 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,9 @@ async def _decode_single(
chunk_spec,
chunk_selection,
out_selection,
is_complete_shard,
)
for chunk_coords, chunk_selection, out_selection in indexer
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
],
out,
)
Expand Down Expand Up @@ -486,7 +487,7 @@ async def _decode_partial_single(
)

indexed_chunks = list(indexer)
all_chunk_coords = {chunk_coords for chunk_coords, _, _ in indexed_chunks}
all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}

# reading bytes of all requested chunks
shard_dict: ShardMapping = {}
Expand Down Expand Up @@ -524,8 +525,9 @@ async def _decode_partial_single(
chunk_spec,
chunk_selection,
out_selection,
is_complete_shard,
)
for chunk_coords, chunk_selection, out_selection in indexer
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
],
out,
)
Expand Down Expand Up @@ -558,8 +560,9 @@ async def _encode_single(
chunk_spec,
chunk_selection,
out_selection,
is_complete_shard,
)
for chunk_coords, chunk_selection, out_selection in indexer
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
],
shard_array,
)
Expand Down Expand Up @@ -601,8 +604,9 @@ async def _encode_partial_single(
chunk_spec,
chunk_selection,
out_selection,
is_complete_shard,
)
for chunk_coords, chunk_selection, out_selection in indexer
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
],
shard_array,
)
Expand Down
6 changes: 4 additions & 2 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,8 +1290,9 @@ async def _get_selection(
self.metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype),
chunk_selection,
out_selection,
is_complete_chunk,
)
for chunk_coords, chunk_selection, out_selection in indexer
for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer
],
out_buffer,
drop_axes=indexer.drop_axes,
Expand Down Expand Up @@ -1417,8 +1418,9 @@ async def _set_selection(
self.metadata.get_chunk_spec(chunk_coords, _config, prototype),
chunk_selection,
out_selection,
is_complete_chunk,
)
for chunk_coords, chunk_selection, out_selection in indexer
for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer
],
value_buffer,
drop_axes=indexer.drop_axes,
Expand Down
60 changes: 34 additions & 26 deletions src/zarr/core/codec_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from zarr.core.common import ChunkCoords, concurrent_map
from zarr.core.config import config
from zarr.core.indexing import SelectorTuple, is_scalar, is_total_slice
from zarr.core.indexing import SelectorTuple, is_scalar
from zarr.core.metadata.v2 import _default_fill_value
from zarr.registry import register_pipeline

Expand Down Expand Up @@ -230,18 +230,18 @@ async def encode_partial_batch(

async def read_batch(
self,
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
out: NDBuffer,
drop_axes: tuple[int, ...] = (),
) -> None:
if self.supports_partial_decode:
chunk_array_batch = await self.decode_partial_batch(
[
(byte_getter, chunk_selection, chunk_spec)
for byte_getter, chunk_spec, chunk_selection, _ in batch_info
for byte_getter, chunk_spec, chunk_selection, *_ in batch_info
]
)
for chunk_array, (_, chunk_spec, _, out_selection) in zip(
for chunk_array, (_, chunk_spec, _, out_selection, _) in zip(
chunk_array_batch, batch_info, strict=False
):
if chunk_array is not None:
Expand All @@ -260,22 +260,19 @@ async def read_batch(
out[out_selection] = fill_value
else:
chunk_bytes_batch = await concurrent_map(
[
(byte_getter, array_spec.prototype)
for byte_getter, array_spec, _, _ in batch_info
],
[(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info],
lambda byte_getter, prototype: byte_getter.get(prototype),
config.get("async.concurrency"),
)
chunk_array_batch = await self.decode_batch(
[
(chunk_bytes, chunk_spec)
for chunk_bytes, (_, chunk_spec, _, _) in zip(
for chunk_bytes, (_, chunk_spec, *_) in zip(
chunk_bytes_batch, batch_info, strict=False
)
],
)
for chunk_array, (_, chunk_spec, chunk_selection, out_selection) in zip(
for chunk_array, (_, chunk_spec, chunk_selection, out_selection, _) in zip(
chunk_array_batch, batch_info, strict=False
):
if chunk_array is not None:
Expand All @@ -296,9 +293,10 @@ def _merge_chunk_array(
out_selection: SelectorTuple,
chunk_spec: ArraySpec,
chunk_selection: SelectorTuple,
is_complete_chunk: bool,
drop_axes: tuple[int, ...],
) -> NDBuffer:
if is_total_slice(chunk_selection, chunk_spec.shape) and value.shape == chunk_spec.shape:
if is_complete_chunk and value.shape == chunk_spec.shape:
return value
if existing_chunk_array is None:
chunk_array = chunk_spec.prototype.nd_buffer.create(
Expand Down Expand Up @@ -327,7 +325,7 @@ def _merge_chunk_array(

async def write_batch(
self,
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
value: NDBuffer,
drop_axes: tuple[int, ...] = (),
) -> None:
Expand All @@ -337,14 +335,14 @@ async def write_batch(
await self.encode_partial_batch(
[
(byte_setter, value, chunk_selection, chunk_spec)
for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info
],
)
else:
await self.encode_partial_batch(
[
(byte_setter, value[out_selection], chunk_selection, chunk_spec)
for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info
],
)

Expand All @@ -361,33 +359,43 @@ async def _read_key(
chunk_bytes_batch = await concurrent_map(
[
(
None if is_total_slice(chunk_selection, chunk_spec.shape) else byte_setter,
None if is_complete_chunk else byte_setter,
chunk_spec.prototype,
)
for byte_setter, chunk_spec, chunk_selection, _ in batch_info
for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info
],
_read_key,
config.get("async.concurrency"),
)
chunk_array_decoded = await self.decode_batch(
[
(chunk_bytes, chunk_spec)
for chunk_bytes, (_, chunk_spec, _, _) in zip(
for chunk_bytes, (_, chunk_spec, *_) in zip(
chunk_bytes_batch, batch_info, strict=False
)
],
)

chunk_array_merged = [
self._merge_chunk_array(
chunk_array, value, out_selection, chunk_spec, chunk_selection, drop_axes
)
for chunk_array, (_, chunk_spec, chunk_selection, out_selection) in zip(
chunk_array_decoded, batch_info, strict=False
chunk_array,
value,
out_selection,
chunk_spec,
chunk_selection,
is_complete_chunk,
drop_axes,
)
for chunk_array, (
_,
chunk_spec,
chunk_selection,
out_selection,
is_complete_chunk,
) in zip(chunk_array_decoded, batch_info, strict=False)
]
chunk_array_batch: list[NDBuffer | None] = []
for chunk_array, (_, chunk_spec, _, _) in zip(
for chunk_array, (_, chunk_spec, *_) in zip(
chunk_array_merged, batch_info, strict=False
):
if chunk_array is None:
Expand All @@ -403,7 +411,7 @@ async def _read_key(
chunk_bytes_batch = await self.encode_batch(
[
(chunk_array, chunk_spec)
for chunk_array, (_, chunk_spec, _, _) in zip(
for chunk_array, (_, chunk_spec, *_) in zip(
chunk_array_batch, batch_info, strict=False
)
],
Expand All @@ -418,7 +426,7 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non
await concurrent_map(
[
(byte_setter, chunk_bytes)
for chunk_bytes, (byte_setter, _, _, _) in zip(
for chunk_bytes, (byte_setter, *_) in zip(
chunk_bytes_batch, batch_info, strict=False
)
],
Expand Down Expand Up @@ -446,7 +454,7 @@ async def encode(

async def read(
self,
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
out: NDBuffer,
drop_axes: tuple[int, ...] = (),
) -> None:
Expand All @@ -461,7 +469,7 @@ async def read(

async def write(
self,
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
value: NDBuffer,
drop_axes: tuple[int, ...] = (),
) -> None:
Expand Down
Loading
Loading