Skip to content

Commit

Permalink
re-implemented concatenation through concatenation of the wrapped str…
Browse files Browse the repository at this point in the history
…uctured array
  • Loading branch information
TomNicholas committed Mar 18, 2024
1 parent 20f2ded commit be8af12
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 110 deletions.
7 changes: 1 addition & 6 deletions virtualizarr/manifests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,4 @@
# This is just to avoid conflicting with some type of file called manifest that .gitignore recommends ignoring.

from .array import ManifestArray # type: ignore # noqa
from .manifest import ( # type: ignore # noqa
ChunkEntry,
ChunkManifest,
concat_manifests,
stack_manifests,
)
from .manifest import ChunkEntry, ChunkManifest # type: ignore # noqa
12 changes: 7 additions & 5 deletions virtualizarr/manifests/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from ..zarr import Codec, ZArray
from .manifest import concat_manifests, stack_manifests
from .manifest import ChunkManifest

if TYPE_CHECKING:
from .array import ManifestArray
Expand Down Expand Up @@ -123,10 +123,11 @@ def concatenate(
new_shape = list(first_shape)
new_shape[axis] = new_length_along_concat_axis

concatenated_manifest = concat_manifests(
[arr.manifest for arr in arrays],
concatenated_manifest_entries = np.concatenate(
[arr.manifest.entries for arr in arrays],
axis=axis,
)
concatenated_manifest = ChunkManifest(entries=concatenated_manifest_entries)

new_zarray = ZArray(
chunks=first_arr.chunks,
Expand Down Expand Up @@ -206,10 +207,11 @@ def stack(
new_shape = list(first_shape)
new_shape.insert(axis, length_along_new_stacked_axis)

stacked_manifest = stack_manifests(
[arr.manifest for arr in arrays],
stacked_manifest_entries = np.stack(
[arr.manifest.entries for arr in arrays],
axis=axis,
)
stacked_manifest = ChunkManifest(entries=stacked_manifest_entries)

# chunk size has changed because a length-1 axis has been inserted
old_chunks = first_arr.chunks
Expand Down
86 changes: 3 additions & 83 deletions virtualizarr/manifests/manifest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, Iterable, Iterator, List, Mapping, NewType, Tuple, Union, cast
from typing import Any, Iterable, Iterator, List, NewType, Tuple, Union, cast

import numpy as np
from pydantic import BaseModel, ConfigDict
Expand All @@ -13,9 +13,7 @@
_CHUNK_KEY = rf"^{_INTEGER}+({_SEPARATOR}{_INTEGER})*$" # matches 1 integer, optionally followed by more integers each separated by a separator (i.e. a period)


ChunkDict = NewType(
"ChunkDict", dict[ChunkKey, dict[str, Union[str, int]]]
) # just the .zattrs (for one array or for the whole store/group)
ChunkDict = NewType("ChunkDict", dict[ChunkKey, dict[str, Union[str, int]]])


class ChunkEntry(BaseModel):
Expand Down Expand Up @@ -190,7 +188,7 @@ def from_kerchunk_chunk_dict(cls, kerchunk_chunk_dict) -> "ChunkManifest":
chunkentries = {
k: ChunkEntry.from_kerchunk(v) for k, v in kerchunk_chunk_dict.items()
}
return ChunkManifest(entries=chunkentries)
return ChunkManifest.from_dict(chunkentries)


def split(key: ChunkKey) -> Tuple[int, ...]:
Expand Down Expand Up @@ -230,81 +228,3 @@ def get_chunk_grid_shape(chunk_keys: Iterable[ChunkKey]) -> Tuple[int, ...]:
max(indices_along_one_dim) + 1 for indices_along_one_dim in zipped_indices
)
return chunk_grid_shape


def concat_manifests(manifests: List["ChunkManifest"], axis: int) -> "ChunkManifest":
"""
Concatenate manifests along an existing dimension.
This only requires adjusting one index of chunk keys along a single dimension.
Note axis is not expected to be negative.
"""
if len(manifests) == 1:
return manifests[0]

chunk_grid_shapes = [manifest.shape_chunk_grid for manifest in manifests]
lengths_along_concat_dim = [shape[axis] for shape in chunk_grid_shapes]

# Note we do not need to change the keys of the first manifest
chunk_index_offsets = np.cumsum(lengths_along_concat_dim)[:-1]
new_entries = [
adjust_chunk_keys(manifest.entries, axis, offset)
for manifest, offset in zip(manifests[1:], chunk_index_offsets)
]
all_entries = [manifests[0].entries] + new_entries
merged_entries = dict((k, v) for d in all_entries for k, v in d.items())

# Arguably don't need to re-perform validation checks on a manifest we created out of already-validated manifests
# Could use pydantic's model_construct classmethod to skip these checks
# But we should actually performance test it because it might be pointless, and current implementation is safer
return ChunkManifest(entries=merged_entries)


def adjust_chunk_keys(
entries: Mapping[ChunkKey, ChunkEntry], axis: int, offset: int
) -> Mapping[ChunkKey, ChunkEntry]:
"""Replace all chunk keys with keys which have been offset along one axis."""

def offset_key(key: ChunkKey, axis: int, offset: int) -> ChunkKey:
inds = split(key)
inds[axis] += offset
return join(inds)

return {offset_key(k, axis, offset): v for k, v in entries.items()}


def stack_manifests(manifests: List[ChunkManifest], axis: int) -> "ChunkManifest":
"""
Stack manifests along a new dimension.
This only requires inserting one index into all chunk keys to add a new dimension.
Note axis is not expected to be negative.
"""

# even if there is only one manifest it still needs a new axis inserted
chunk_indexes_along_new_dim = range(len(manifests))
new_entries = [
insert_new_axis_into_chunk_keys(manifest.entries, axis, new_index_value)
for manifest, new_index_value in zip(manifests, chunk_indexes_along_new_dim)
]
merged_entries = dict((k, v) for d in new_entries for k, v in d.items())

# Arguably don't need to re-perform validation checks on a manifest we created out of already-validated manifests
# Could use pydantic's model_construct classmethod to skip these checks
# But we should actually performance test it because it might be pointless, and current implementation is safer
return ChunkManifest(entries=merged_entries)


def insert_new_axis_into_chunk_keys(
entries: Mapping[ChunkKey, ChunkEntry], axis: int, new_index_value: int
) -> Mapping[ChunkKey, ChunkEntry]:
"""Replace all chunk keys with keys which have a new axis inserted, with a given value."""

def insert_axis(key: ChunkKey, new_axis: int, index_value: int) -> ChunkKey:
inds = split(key)
inds.insert(new_axis, index_value)
return join(inds)

return {insert_axis(k, axis, new_index_value): v for k, v in entries.items()}
28 changes: 15 additions & 13 deletions virtualizarr/tests/test_manifests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_create_manifestarray(self):
"0.1.0": {"path": "s3://bucket/foo.nc", "offset": 300, "length": 100},
"0.1.1": {"path": "s3://bucket/foo.nc", "offset": 400, "length": 100},
}
manifest = ChunkManifest(entries=chunks_dict)
manifest = ChunkManifest.from_dict(chunks_dict)
chunks = (5, 1, 10)
shape = (5, 2, 20)
zarray = ZArray(
Expand All @@ -38,7 +38,7 @@ def test_create_invalid_manifestarray(self):
chunks_dict = {
"0.0.0": {"path": "foo.nc", "offset": 100, "length": 100},
}
manifest = ChunkManifest(entries=chunks_dict)
manifest = ChunkManifest.from_dict(chunks_dict)
chunks = (5, 1, 10)
shape = (5, 2, 20)
zarray = ZArray(
Expand Down Expand Up @@ -79,7 +79,7 @@ def test_equals(self):
"0.1.0": {"path": "s3://bucket/foo.nc", "offset": 300, "length": 100},
"0.1.1": {"path": "s3://bucket/foo.nc", "offset": 400, "length": 100},
}
manifest = ChunkManifest(entries=chunks_dict)
manifest = ChunkManifest.from_dict(chunks_dict)
chunks = (5, 1, 10)
shape = (5, 2, 20)
zarray = ZArray(
Expand Down Expand Up @@ -118,14 +118,14 @@ def test_not_equal_chunk_entries(self):
"0.0.0": {"path": "foo.nc", "offset": 100, "length": 100},
"0.0.1": {"path": "foo.nc", "offset": 200, "length": 100},
}
manifest1 = ChunkManifest(entries=chunks_dict1)
manifest1 = ChunkManifest.from_dict(chunks_dict1)
marr1 = ManifestArray(zarray=zarray, chunkmanifest=manifest1)

chunks_dict2 = {
"0.0.0": {"path": "foo.nc", "offset": 300, "length": 100},
"0.0.1": {"path": "foo.nc", "offset": 400, "length": 100},
}
manifest2 = ChunkManifest(entries=chunks_dict2)
manifest2 = ChunkManifest.from_dict(chunks_dict2)
marr2 = ManifestArray(zarray=zarray, chunkmanifest=manifest2)
assert not (marr1 == marr2).all()

Expand Down Expand Up @@ -154,14 +154,14 @@ def test_concat(self):
"0.0.0": {"path": "foo.nc", "offset": 100, "length": 100},
"0.0.1": {"path": "foo.nc", "offset": 200, "length": 100},
}
manifest1 = ChunkManifest(entries=chunks_dict1)
manifest1 = ChunkManifest.from_dict(chunks_dict1)
marr1 = ManifestArray(zarray=zarray, chunkmanifest=manifest1)

chunks_dict2 = {
"0.0.0": {"path": "foo.nc", "offset": 300, "length": 100},
"0.0.1": {"path": "foo.nc", "offset": 400, "length": 100},
}
manifest2 = ChunkManifest(entries=chunks_dict2)
manifest2 = ChunkManifest.from_dict(chunks_dict2)
marr2 = ManifestArray(zarray=zarray, chunkmanifest=manifest2)

result = np.concatenate([marr1, marr2], axis=1)
Expand Down Expand Up @@ -199,14 +199,14 @@ def test_stack(self):
"0.0": {"path": "foo.nc", "offset": 100, "length": 100},
"0.1": {"path": "foo.nc", "offset": 200, "length": 100},
}
manifest1 = ChunkManifest(entries=chunks_dict1)
manifest1 = ChunkManifest.from_dict(chunks_dict1)
marr1 = ManifestArray(zarray=zarray, chunkmanifest=manifest1)

chunks_dict2 = {
"0.0": {"path": "foo.nc", "offset": 300, "length": 100},
"0.1": {"path": "foo.nc", "offset": 400, "length": 100},
}
manifest2 = ChunkManifest(entries=chunks_dict2)
manifest2 = ChunkManifest.from_dict(chunks_dict2)
marr2 = ManifestArray(zarray=zarray, chunkmanifest=manifest2)

result = np.stack([marr1, marr2], axis=1)
Expand Down Expand Up @@ -242,28 +242,30 @@ def test_refuse_combine():
chunks_dict1 = {
"0.0.0": {"path": "foo.nc", "offset": 100, "length": 100},
}
chunkmanifest1 = ChunkManifest.from_dict(chunks_dict1)
chunks_dict2 = {
"0.0.0": {"path": "foo.nc", "offset": 300, "length": 100},
}
marr1 = ManifestArray(zarray=zarray_common, chunkmanifest=chunks_dict1)
chunkmanifest2 = ChunkManifest.from_dict(chunks_dict2)
marr1 = ManifestArray(zarray=zarray_common, chunkmanifest=chunkmanifest1)

zarray_wrong_compressor = zarray_common.copy()
zarray_wrong_compressor["compressor"] = None
marr2 = ManifestArray(zarray=zarray_wrong_compressor, chunkmanifest=chunks_dict2)
marr2 = ManifestArray(zarray=zarray_wrong_compressor, chunkmanifest=chunkmanifest2)
for func in [np.concatenate, np.stack]:
with pytest.raises(NotImplementedError, match="different codecs"):
func([marr1, marr2], axis=0)

zarray_wrong_dtype = zarray_common.copy()
zarray_wrong_dtype["dtype"] = np.dtype("int64")
marr2 = ManifestArray(zarray=zarray_wrong_dtype, chunkmanifest=chunks_dict2)
marr2 = ManifestArray(zarray=zarray_wrong_dtype, chunkmanifest=chunkmanifest2)
for func in [np.concatenate, np.stack]:
with pytest.raises(ValueError, match="inconsistent dtypes"):
func([marr1, marr2], axis=0)

zarray_wrong_dtype = zarray_common.copy()
zarray_wrong_dtype["dtype"] = np.dtype("int64")
marr2 = ManifestArray(zarray=zarray_wrong_dtype, chunkmanifest=chunks_dict2)
marr2 = ManifestArray(zarray=zarray_wrong_dtype, chunkmanifest=chunkmanifest2)
for func in [np.concatenate, np.stack]:
with pytest.raises(ValueError, match="inconsistent dtypes"):
func([marr1, marr2], axis=0)
11 changes: 8 additions & 3 deletions virtualizarr/tests/test_manifests/test_manifest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest

from virtualizarr.manifests import ChunkManifest, concat_manifests, stack_manifests
from virtualizarr.manifests import ChunkManifest


class TestCreateManifest:
Expand Down Expand Up @@ -108,7 +109,10 @@ def test_concat(self):
}
)

result = concat_manifests([manifest1, manifest2], axis=axis)
result_manifest = np.concatenate(
[manifest1.entries, manifest2.entries], axis=axis
)
result = ChunkManifest(entries=result_manifest)
assert result.dict() == expected.dict()

def test_stack(self):
Expand All @@ -134,7 +138,8 @@ def test_stack(self):
}
)

result = stack_manifests([manifest1, manifest2], axis=axis)
result_manifest = np.stack([manifest1.entries, manifest2.entries], axis=axis)
result = ChunkManifest(entries=result_manifest)
assert result.dict() == expected.dict()


Expand Down

0 comments on commit be8af12

Please sign in to comment.