Skip to content

Commit

Permalink
Add manifest.Shard to replace triple. (#255)
Browse files Browse the repository at this point in the history
* Add `manifest.Shard` to replace triple.

This would be used in a subsequent PR and also simplifies tests a little.

Most changes are due to the change from a tuple to a dataclass.

Signed-off-by: Mihai Maruseac <[email protected]>

* Remove unused import

Signed-off-by: Mihai Maruseac <[email protected]>

---------

Signed-off-by: Mihai Maruseac <[email protected]>
  • Loading branch information
mihaimaruseac authored Jul 24, 2024
1 parent 6f01724 commit c3c4110
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 40 deletions.
28 changes: 25 additions & 3 deletions model_signing/manifest/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,28 @@ def __eq__(self, other: Self):
return self._item_to_digest == other._item_to_digest


@dataclasses.dataclass(frozen=True, order=True)
class Shard:
"""A dataclass to hold information about a file shard.
Attributes:
path: The path to the file, relative to the model root.
start: The start offset of the shard (included).
end: The end offset of the shard (not included).
"""

path: pathlib.PurePath
start: int
end: int

def __str__(self) -> str:
"""Converts the item to a canonicalized string representation.
The format is {path}:{start}:{end}, which should also be easy to decode.
"""
return f"{str(self.path)}:{self.start}:{self.end}"


@dataclasses.dataclass
class ShardedFileManifestItem(ManifestItem):
"""A manifest item that records a file shard together with its digest."""
Expand All @@ -146,7 +168,7 @@ def __init__(
path: pathlib.PurePath,
start: int,
end: int,
digest: hashing.Digest
digest: hashing.Digest,
):
"""Builds a manifest item pairing a file shard with its digest.
Expand All @@ -163,9 +185,9 @@ def __init__(
self.digest = digest

@property
def input_tuple(self) -> tuple[pathlib.PurePath, int, int]:
def input_tuple(self) -> Shard:
"""Returns the triple that uniquely determines the manifest item."""
return (self.path, self.start, self.end)
return Shard(self.path, self.start, self.end)


class ShardLevelManifest(FileLevelManifest):
Expand Down
65 changes: 28 additions & 37 deletions model_signing/serialization/serialize_by_file_shard_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
pytest model_signing/ --update_goldens
"""

import dataclasses
import pathlib
import pytest

Expand Down Expand Up @@ -301,32 +300,22 @@ def test_shard_size_changes_digests(self, sample_model_folder):
assert manifest1.digest.digest_value != manifest2.digest.digest_value


@dataclasses.dataclass(frozen=True, order=True)
class _Shard:
"""A shard of a file from a manifest."""

path: str
start: int
end: int


def _extract_shard_items_from_manifest(
manifest: manifest.ShardLevelManifest,
) -> dict[_Shard, str]:
) -> dict[manifest.Shard, str]:
"""Builds a dictionary representation of the items in a manifest.
Every item is mapped to its digest.
Used in multiple tests to check that we obtained the expected manifest.
"""
return {
# convert to file path (relative to model) string and endpoints
_Shard(str(shard[0]), shard[1], shard[2]): digest.digest_hex
shard: digest.digest_hex
for shard, digest in manifest._item_to_digest.items()
}


def _parse_shard_and_digest(line: str) -> tuple[_Shard, str]:
def _parse_shard_and_digest(line: str) -> tuple[manifest.Shard, str]:
"""Reads a file shard and its digest from a line in the golden file.
Args:
Expand All @@ -336,7 +325,7 @@ def _parse_shard_and_digest(line: str) -> tuple[_Shard, str]:
The shard tuple and the digest corresponding to the line that was read.
"""
path, start, end, digest = line.strip().split(":")
shard = _Shard(path, int(start), int(end))
shard = manifest.Shard(pathlib.PurePosixPath(path), int(start), int(end))
return shard, digest


Expand Down Expand Up @@ -370,18 +359,16 @@ def test_known_models(self, request, model_fixture_name):
serializer = serialize_by_file_shard.ManifestSerializer(
self._hasher_factory
)
manifest = serializer.serialize(model)
items = _extract_shard_items_from_manifest(manifest)
manifest_file = serializer.serialize(model)
items = _extract_shard_items_from_manifest(manifest_file)

# Compare with golden, or write to golden (approximately "assert")
if should_update:
with open(golden_path, "w", encoding="utf-8") as f:
for shard, digest in sorted(items.items()):
f.write(
f"{shard.path}:{shard.start}:{shard.end}:{digest}\n"
)
f.write(f"{shard}:{digest}\n")
else:
found_items: dict[_Shard, str] = {}
found_items: dict[manifest.Shard, str] = {}
with open(golden_path, "r", encoding="utf-8") as f:
for line in f:
shard, digest = _parse_shard_and_digest(line)
Expand All @@ -403,18 +390,16 @@ def test_known_models_small_shards(self, request, model_fixture_name):
serializer = serialize_by_file_shard.ManifestSerializer(
self._hasher_factory_small_shards
)
manifest = serializer.serialize(model)
items = _extract_shard_items_from_manifest(manifest)
manifest_file = serializer.serialize(model)
items = _extract_shard_items_from_manifest(manifest_file)

# Compare with golden, or write to golden (approximately "assert")
if should_update:
with open(golden_path, "w", encoding="utf-8") as f:
for shard, digest in sorted(items.items()):
f.write(
f"{shard.path}:{shard.start}:{shard.end}:{digest}\n"
)
f.write(f"{shard}:{digest}\n")
else:
found_items: dict[_Shard, str] = {}
found_items: dict[manifest.Shard, str] = {}
with open(golden_path, "r", encoding="utf-8") as f:
for line in f:
shard, digest = _parse_shard_and_digest(line)
Expand Down Expand Up @@ -522,9 +507,8 @@ def _check_manifests_match_except_on_renamed_file(
old_manifest._item_to_digest
)
for shard, digest in new_manifest._item_to_digest.items():
path, start, end = shard
if path.name == new_name:
old_shard = (old_name, start, end)
if shard.path.name == new_name:
old_shard = manifest.Shard(old_name, shard.start, shard.end)
assert old_manifest._item_to_digest[old_shard] == digest
else:
assert old_manifest._item_to_digest[shard] == digest
Expand Down Expand Up @@ -566,13 +550,14 @@ def _check_manifests_match_except_on_renamed_dir(
old_manifest._item_to_digest
)
for shard, digest in new_manifest._item_to_digest.items():
path, start, end = shard
if new_name in path.parts:
if new_name in shard.path.parts:
parts = [
old_name if part == new_name else part
for part in path.parts
for part in shard.path.parts
]
old = (pathlib.PurePosixPath(*parts), start, end)
old = manifest.Shard(
pathlib.PurePosixPath(*parts), shard.start, shard.end
)
assert old_manifest._item_to_digest[old] == digest
else:
assert old_manifest._item_to_digest[shard] == digest
Expand Down Expand Up @@ -627,10 +612,10 @@ def _check_manifests_match_except_on_entry(
old_manifest._item_to_digest
)
for shard, digest in new_manifest._item_to_digest.items():
path, _, _ = shard
if path == expected_mismatch_path:
if shard.path == expected_mismatch_path:
# Note that the file size changes
assert old_manifest._item_to_digest[(path, 0, 23)] != digest
item = manifest.Shard(shard.path, 0, 23)
assert old_manifest._item_to_digest[item] != digest
else:
assert old_manifest._item_to_digest[shard] == digest

Expand Down Expand Up @@ -668,3 +653,9 @@ def test_max_workers_does_not_change_digest(self, sample_model_folder):

assert manifest1 == manifest2
assert manifest1 == manifest3


def test_shard_to_string(self):
"""Ensure the shard's `__str__` method behaves as assumed."""
shard = manifest.Shard(pathlib.PurePosixPath("a"), 0, 42)
assert str(shard) == "a:0:42"

0 comments on commit c3c4110

Please sign in to comment.