Skip to content

Commit

Permalink
Provide file hashes in the URLs to avoid unnecessary file downloads (…
Browse files Browse the repository at this point in the history
…bandwidth saver) (#1433)

Supply sha256 query parameters using boto3 to avoid hundreds of extra Gigabytes of downloads each day during pipenv and poetry resolution lock cycles.

Fixes point 1 in pytorch/pytorch#76557
Fixes #1347
  • Loading branch information
matteius authored Sep 23, 2023
1 parent 553b4df commit dbad8b7
Showing 1 changed file with 44 additions and 15 deletions.
59 changes: 44 additions & 15 deletions s3_management/manage.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
#!/usr/bin/env python

import argparse
import base64
import dataclasses
import functools
import time

from os import path, makedirs
from datetime import datetime
from collections import defaultdict
from typing import Iterator, List, Type, Dict, Set, TypeVar, Optional
from typing import Iterable, List, Type, Dict, Set, TypeVar, Optional
from re import sub, match, search
from packaging.version import parse

import boto3


S3 = boto3.resource('s3')
CLIENT = boto3.client('s3')
BUCKET = S3.Bucket('pytorch')

ACCEPTED_FILE_EXTENSIONS = ("whl", "zip", "tar.gz")
Expand Down Expand Up @@ -107,6 +109,23 @@

S3IndexType = TypeVar('S3IndexType', bound='S3Index')


@dataclasses.dataclass(frozen=True)
@functools.total_ordering
class S3Object:
key: str
checksum: str | None

def __str__(self):
return self.key

def __eq__(self, other):
return self.key == other.key

def __lt__(self, other):
return self.key < other.key


def extract_package_build_time(full_package_name: str) -> datetime:
result = search(PACKAGE_DATE_REGEX, full_package_name)
if result is not None:
Expand All @@ -124,7 +143,7 @@ def between_bad_dates(package_build_time: datetime):


class S3Index:
def __init__(self: S3IndexType, objects: List[str], prefix: str) -> None:
def __init__(self: S3IndexType, objects: List[S3Object], prefix: str) -> None:
self.objects = objects
self.prefix = prefix.rstrip("/")
self.html_name = PREFIXES_WITH_HTML[self.prefix]
Expand All @@ -134,7 +153,7 @@ def __init__(self: S3IndexType, objects: List[str], prefix: str) -> None:
path.dirname(obj) for obj in objects if path.dirname != prefix
}

def nightly_packages_to_show(self: S3IndexType) -> Set[str]:
def nightly_packages_to_show(self: S3IndexType) -> Set[S3Object]:
"""Finding packages to show based on a threshold we specify
Basically takes our S3 packages, normalizes the version for easier
Expand Down Expand Up @@ -174,8 +193,8 @@ def nightly_packages_to_show(self: S3IndexType) -> Set[str]:
if self.normalize_package_version(obj) in to_hide
})

def is_obj_at_root(self, obj:str) -> bool:
return path.dirname(obj) == self.prefix
def is_obj_at_root(self, obj: S3Object) -> bool:
return path.dirname(str(obj)) == self.prefix

def _resolve_subdir(self, subdir: Optional[str] = None) -> str:
if not subdir:
Expand All @@ -187,7 +206,7 @@ def gen_file_list(
self,
subdir: Optional[str]=None,
package_name: Optional[str] = None
) -> Iterator[str]:
) -> Iterable[S3Object]:
objects = (
self.nightly_packages_to_show() if self.prefix == 'whl/nightly'
else self.objects
Expand All @@ -197,23 +216,23 @@ def gen_file_list(
if package_name is not None:
if self.obj_to_package_name(obj) != package_name:
continue
if self.is_obj_at_root(obj) or obj.startswith(subdir):
if self.is_obj_at_root(obj) or str(obj).startswith(subdir):
yield obj

def get_package_names(self, subdir: Optional[str] = None) -> List[str]:
return sorted(set(self.obj_to_package_name(obj) for obj in self.gen_file_list(subdir)))

def normalize_package_version(self: S3IndexType, obj: str) -> str:
def normalize_package_version(self: S3IndexType, obj: S3Object) -> str:
# removes the GPU specifier from the package name as well as
# unnecessary things like the file extension, architecture name, etc.
return sub(
r"%2B.*",
"",
"-".join(path.basename(obj).split("-")[:2])
"-".join(path.basename(str(obj)).split("-")[:2])
)

def obj_to_package_name(self, obj: str) -> str:
return path.basename(obj).split('-', 1)[0]
def obj_to_package_name(self, obj: S3Object) -> str:
return path.basename(str(obj)).split('-', 1)[0]

def to_legacy_html(
self,
Expand Down Expand Up @@ -258,7 +277,8 @@ def to_simple_package_html(
out.append(' <body>')
out.append(' <h1>Links for {}</h1>'.format(package_name.lower().replace("_","-")))
for obj in sorted(self.gen_file_list(subdir, package_name)):
out.append(f' <a href="/{obj}">{path.basename(obj).replace("%2B","+")}</a><br/>')
maybe_fragment = f"#sha256={obj.checksum}" if obj.checksum else ""
out.append(f' <a href="/{obj}{maybe_fragment}">{path.basename(obj).replace("%2B","+")}</a><br/>')
# Adding html footer
out.append(' </body>')
out.append('</html>')
Expand Down Expand Up @@ -319,7 +339,6 @@ def upload_pep503_htmls(self) -> None:
Body=self.to_simple_package_html(subdir=subdir, package_name=pkg_name)
)


def save_legacy_html(self) -> None:
for subdir in self.subdirs:
print(f"INFO Saving {subdir}/{self.html_name}")
Expand Down Expand Up @@ -351,10 +370,18 @@ def from_S3(cls: Type[S3IndexType], prefix: str) -> S3IndexType:
for pattern in ACCEPTED_SUBDIR_PATTERNS
]) and obj.key.endswith(ACCEPTED_FILE_EXTENSIONS)
if is_acceptable:
# Add PEP 503-compatible hashes to URLs to allow clients to avoid spurious downloads, if possible.
response = obj.meta.client.head_object(Bucket=BUCKET.name, Key=obj.key, ChecksumMode="ENABLED")
sha256 = (_b64 := response.get("ChecksumSHA256")) and base64.b64decode(_b64).hex()
sanitized_key = obj.key.replace("+", "%2B")
objects.append(sanitized_key)
s3_object = S3Object(
key=sanitized_key,
checksum=sha256,
)
objects.append(s3_object)
return cls(objects, prefix)


def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser("Manage S3 HTML indices for PyTorch")
parser.add_argument(
Expand All @@ -366,6 +393,7 @@ def create_parser() -> argparse.ArgumentParser:
parser.add_argument("--generate-pep503", action="store_true")
return parser


def main():
parser = create_parser()
args = parser.parse_args()
Expand All @@ -390,5 +418,6 @@ def main():
if args.generate_pep503:
idx.upload_pep503_htmls()


if __name__ == "__main__":
main()

0 comments on commit dbad8b7

Please sign in to comment.