diff --git a/virtualizarr/manifests/__init__.py b/virtualizarr/manifests/__init__.py index 39f60c59..c317ed6a 100644 --- a/virtualizarr/manifests/__init__.py +++ b/virtualizarr/manifests/__init__.py @@ -2,9 +2,4 @@ # This is just to avoid conflicting with some type of file called manifest that .gitignore recommends ignoring. from .array import ManifestArray # type: ignore # noqa -from .manifest import ( # type: ignore # noqa - ChunkEntry, - ChunkManifest, - concat_manifests, - stack_manifests, -) +from .manifest import ChunkEntry, ChunkManifest # type: ignore # noqa diff --git a/virtualizarr/manifests/array_api.py b/virtualizarr/manifests/array_api.py index 9333d88e..4bfea632 100644 --- a/virtualizarr/manifests/array_api.py +++ b/virtualizarr/manifests/array_api.py @@ -4,7 +4,7 @@ import numpy as np from ..zarr import Codec, ZArray -from .manifest import concat_manifests, stack_manifests +from .manifest import ChunkManifest if TYPE_CHECKING: from .array import ManifestArray @@ -123,10 +123,11 @@ def concatenate( new_shape = list(first_shape) new_shape[axis] = new_length_along_concat_axis - concatenated_manifest = concat_manifests( - [arr.manifest for arr in arrays], + concatenated_manifest_entries = np.concatenate( + [arr.manifest.entries for arr in arrays], axis=axis, ) + concatenated_manifest = ChunkManifest(entries=concatenated_manifest_entries) new_zarray = ZArray( chunks=first_arr.chunks, @@ -206,10 +207,11 @@ def stack( new_shape = list(first_shape) new_shape.insert(axis, length_along_new_stacked_axis) - stacked_manifest = stack_manifests( - [arr.manifest for arr in arrays], + stacked_manifest_entries = np.stack( + [arr.manifest.entries for arr in arrays], axis=axis, ) + stacked_manifest = ChunkManifest(entries=stacked_manifest_entries) # chunk size has changed because a length-1 axis has been inserted old_chunks = first_arr.chunks diff --git a/virtualizarr/manifests/manifest.py b/virtualizarr/manifests/manifest.py index fb14c297..456a1333 100644 --- a/virtualizarr/manifests/manifest.py +++ b/virtualizarr/manifests/manifest.py @@ -1,5 +1,5 @@ import re -from typing import Any, Iterable, Iterator, List, Mapping, NewType, Tuple, Union, cast +from typing import Any, Iterable, Iterator, List, NewType, Tuple, Union, cast import numpy as np from pydantic import BaseModel, ConfigDict @@ -13,9 +13,7 @@ _CHUNK_KEY = rf"^{_INTEGER}+({_SEPARATOR}{_INTEGER})*$" # matches 1 integer, optionally followed by more integers each separated by a separator (i.e. a period) -ChunkDict = NewType( - "ChunkDict", dict[ChunkKey, dict[str, Union[str, int]]] -) # just the .zattrs (for one array or for the whole store/group) +ChunkDict = NewType("ChunkDict", dict[ChunkKey, dict[str, Union[str, int]]]) class ChunkEntry(BaseModel): @@ -190,7 +188,7 @@ def from_kerchunk_chunk_dict(cls, kerchunk_chunk_dict) -> "ChunkManifest": chunkentries = { k: ChunkEntry.from_kerchunk(v) for k, v in kerchunk_chunk_dict.items() } - return ChunkManifest(entries=chunkentries) + return ChunkManifest.from_dict(chunkentries) def split(key: ChunkKey) -> Tuple[int, ...]: @@ -230,81 +228,3 @@ def get_chunk_grid_shape(chunk_keys: Iterable[ChunkKey]) -> Tuple[int, ...]: max(indices_along_one_dim) + 1 for indices_along_one_dim in zipped_indices ) return chunk_grid_shape - - -def concat_manifests(manifests: List["ChunkManifest"], axis: int) -> "ChunkManifest": - """ - Concatenate manifests along an existing dimension. - - This only requires adjusting one index of chunk keys along a single dimension. - - Note axis is not expected to be negative. - """ - if len(manifests) == 1: - return manifests[0] - - chunk_grid_shapes = [manifest.shape_chunk_grid for manifest in manifests] - lengths_along_concat_dim = [shape[axis] for shape in chunk_grid_shapes] - - # Note we do not need to change the keys of the first manifest - chunk_index_offsets = np.cumsum(lengths_along_concat_dim)[:-1] - new_entries = [ - adjust_chunk_keys(manifest.entries, axis, offset) - for manifest, offset in zip(manifests[1:], chunk_index_offsets) - ] - all_entries = [manifests[0].entries] + new_entries - merged_entries = dict((k, v) for d in all_entries for k, v in d.items()) - - # Arguably don't need to re-perform validation checks on a manifest we created out of already-validated manifests - # Could use pydantic's model_construct classmethod to skip these checks - # But we should actually performance test it because it might be pointless, and current implementation is safer - return ChunkManifest(entries=merged_entries) - - -def adjust_chunk_keys( - entries: Mapping[ChunkKey, ChunkEntry], axis: int, offset: int -) -> Mapping[ChunkKey, ChunkEntry]: - """Replace all chunk keys with keys which have been offset along one axis.""" - - def offset_key(key: ChunkKey, axis: int, offset: int) -> ChunkKey: - inds = split(key) - inds[axis] += offset - return join(inds) - - return {offset_key(k, axis, offset): v for k, v in entries.items()} - - -def stack_manifests(manifests: List[ChunkManifest], axis: int) -> "ChunkManifest": - """ - Stack manifests along a new dimension. - - This only requires inserting one index into all chunk keys to add a new dimension. - - Note axis is not expected to be negative. - """ - - # even if there is only one manifest it still needs a new axis inserted - chunk_indexes_along_new_dim = range(len(manifests)) - new_entries = [ - insert_new_axis_into_chunk_keys(manifest.entries, axis, new_index_value) - for manifest, new_index_value in zip(manifests, chunk_indexes_along_new_dim) - ] - merged_entries = dict((k, v) for d in new_entries for k, v in d.items()) - - # Arguably don't need to re-perform validation checks on a manifest we created out of already-validated manifests - # Could use pydantic's model_construct classmethod to skip these checks - # But we should actually performance test it because it might be pointless, and current implementation is safer - return ChunkManifest(entries=merged_entries) - - -def insert_new_axis_into_chunk_keys( - entries: Mapping[ChunkKey, ChunkEntry], axis: int, new_index_value: int -) -> Mapping[ChunkKey, ChunkEntry]: - """Replace all chunk keys with keys which have a new axis inserted, with a given value.""" - - def insert_axis(key: ChunkKey, new_axis: int, index_value: int) -> ChunkKey: - inds = split(key) - inds.insert(new_axis, index_value) - return join(inds) - - return {insert_axis(k, axis, new_index_value): v for k, v in entries.items()} diff --git a/virtualizarr/tests/test_manifests/test_array.py b/virtualizarr/tests/test_manifests/test_array.py index aa10d91c..b587f88a 100644 --- a/virtualizarr/tests/test_manifests/test_array.py +++ b/virtualizarr/tests/test_manifests/test_array.py @@ -13,7 +13,7 @@ def test_create_manifestarray(self): "0.1.0": {"path": "s3://bucket/foo.nc", "offset": 300, "length": 100}, "0.1.1": {"path": "s3://bucket/foo.nc", "offset": 400, "length": 100}, } - manifest = ChunkManifest(entries=chunks_dict) + manifest = ChunkManifest.from_dict(chunks_dict) chunks = (5, 1, 10) shape = (5, 2, 20) zarray = ZArray( @@ -38,7 +38,7 @@ def test_create_invalid_manifestarray(self): chunks_dict = { "0.0.0": {"path": "foo.nc", "offset": 100, "length": 100}, } - manifest = ChunkManifest(entries=chunks_dict) + manifest = ChunkManifest.from_dict(chunks_dict) chunks = (5, 1, 10) shape = (5, 2, 20) zarray = ZArray( @@ -79,7 +79,7 @@ def test_equals(self): "0.1.0": {"path": "s3://bucket/foo.nc", "offset": 300, "length": 100}, "0.1.1": {"path": "s3://bucket/foo.nc", "offset": 400, "length": 100}, } - manifest = ChunkManifest(entries=chunks_dict) + manifest = ChunkManifest.from_dict(chunks_dict) chunks = (5, 1, 10) shape = (5, 2, 20) zarray = ZArray( @@ -118,14 +118,14 @@ def test_not_equal_chunk_entries(self): "0.0.0": {"path": "foo.nc", "offset": 100, "length": 100}, "0.0.1": {"path": "foo.nc", "offset": 200, "length": 100}, } - manifest1 = ChunkManifest(entries=chunks_dict1) + manifest1 = ChunkManifest.from_dict(chunks_dict1) marr1 = ManifestArray(zarray=zarray, chunkmanifest=manifest1) chunks_dict2 = { "0.0.0": {"path": "foo.nc", "offset": 300, "length": 100}, "0.0.1": {"path": "foo.nc", "offset": 400, "length": 100}, } - manifest2 = ChunkManifest(entries=chunks_dict2) + manifest2 = ChunkManifest.from_dict(chunks_dict2) marr2 = ManifestArray(zarray=zarray, chunkmanifest=manifest2) assert not (marr1 == marr2).all() @@ -154,14 +154,14 @@ def test_concat(self): "0.0.0": {"path": "foo.nc", "offset": 100, "length": 100}, "0.0.1": {"path": "foo.nc", "offset": 200, "length": 100}, } - manifest1 = ChunkManifest(entries=chunks_dict1) + manifest1 = ChunkManifest.from_dict(chunks_dict1) marr1 = ManifestArray(zarray=zarray, chunkmanifest=manifest1) chunks_dict2 = { "0.0.0": {"path": "foo.nc", "offset": 300, "length": 100}, "0.0.1": {"path": "foo.nc", "offset": 400, "length": 100}, } - manifest2 = ChunkManifest(entries=chunks_dict2) + manifest2 = ChunkManifest.from_dict(chunks_dict2) marr2 = ManifestArray(zarray=zarray, chunkmanifest=manifest2) result = np.concatenate([marr1, marr2], axis=1) @@ -199,14 +199,14 @@ def test_stack(self): "0.0": {"path": "foo.nc", "offset": 100, "length": 100}, "0.1": {"path": "foo.nc", "offset": 200, "length": 100}, } - manifest1 = ChunkManifest(entries=chunks_dict1) + manifest1 = ChunkManifest.from_dict(chunks_dict1) marr1 = ManifestArray(zarray=zarray, chunkmanifest=manifest1) chunks_dict2 = { "0.0": {"path": "foo.nc", "offset": 300, "length": 100}, "0.1": {"path": "foo.nc", "offset": 400, "length": 100}, } - manifest2 = ChunkManifest(entries=chunks_dict2) + manifest2 = ChunkManifest.from_dict(chunks_dict2) marr2 = ManifestArray(zarray=zarray, chunkmanifest=manifest2) result = np.stack([marr1, marr2], axis=1) @@ -242,28 +242,30 @@ def test_refuse_combine(): chunks_dict1 = { "0.0.0": {"path": "foo.nc", "offset": 100, "length": 100}, } + chunkmanifest1 = ChunkManifest.from_dict(chunks_dict1) chunks_dict2 = { "0.0.0": {"path": "foo.nc", "offset": 300, "length": 100}, } - marr1 = ManifestArray(zarray=zarray_common, chunkmanifest=chunks_dict1) + chunkmanifest2 = ChunkManifest.from_dict(chunks_dict2) + marr1 = ManifestArray(zarray=zarray_common, chunkmanifest=chunkmanifest1) zarray_wrong_compressor = zarray_common.copy() zarray_wrong_compressor["compressor"] = None - marr2 = ManifestArray(zarray=zarray_wrong_compressor, chunkmanifest=chunks_dict2) + marr2 = ManifestArray(zarray=zarray_wrong_compressor, chunkmanifest=chunkmanifest2) for func in [np.concatenate, np.stack]: with pytest.raises(NotImplementedError, match="different codecs"): func([marr1, marr2], axis=0) zarray_wrong_dtype = zarray_common.copy() zarray_wrong_dtype["dtype"] = np.dtype("int64") - marr2 = ManifestArray(zarray=zarray_wrong_dtype, chunkmanifest=chunks_dict2) + marr2 = ManifestArray(zarray=zarray_wrong_dtype, chunkmanifest=chunkmanifest2) for func in [np.concatenate, np.stack]: with pytest.raises(ValueError, match="inconsistent dtypes"): func([marr1, marr2], axis=0) zarray_wrong_dtype = zarray_common.copy() zarray_wrong_dtype["dtype"] = np.dtype("int64") - marr2 = ManifestArray(zarray=zarray_wrong_dtype, chunkmanifest=chunks_dict2) + marr2 = ManifestArray(zarray=zarray_wrong_dtype, chunkmanifest=chunkmanifest2) for func in [np.concatenate, np.stack]: with pytest.raises(ValueError, match="inconsistent dtypes"): func([marr1, marr2], axis=0) diff --git a/virtualizarr/tests/test_manifests/test_manifest.py b/virtualizarr/tests/test_manifests/test_manifest.py index 38b16f09..062ad7ad 100644 --- a/virtualizarr/tests/test_manifests/test_manifest.py +++ b/virtualizarr/tests/test_manifests/test_manifest.py @@ -1,6 +1,7 @@ +import numpy as np import pytest -from virtualizarr.manifests import ChunkManifest, concat_manifests, stack_manifests +from virtualizarr.manifests import ChunkManifest class TestCreateManifest: @@ -108,7 +109,10 @@ def test_concat(self): } ) - result = concat_manifests([manifest1, manifest2], axis=axis) + result_manifest = np.concatenate( + [manifest1.entries, manifest2.entries], axis=axis + ) + result = ChunkManifest(entries=result_manifest) assert result.dict() == expected.dict() def test_stack(self): @@ -134,7 +138,8 @@ def test_stack(self): } ) - result = stack_manifests([manifest1, manifest2], axis=axis) + result_manifest = np.stack([manifest1.entries, manifest2.entries], axis=axis) + result = ChunkManifest(entries=result_manifest) assert result.dict() == expected.dict()