Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide file hashes in the URLs to avoid unnecessary file downloads (bandwidth saver) #1433

Merged
merged 15 commits into from
Sep 23, 2023
Merged
40 changes: 24 additions & 16 deletions s3_management/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


S3 = boto3.resource('s3')
CLIENT = boto3.client('s3')
BUCKET = S3.Bucket('pytorch')

ACCEPTED_FILE_EXTENSIONS = ("whl", "zip", "tar.gz")
Expand Down Expand Up @@ -121,8 +120,8 @@ def between_bad_dates(package_build_time: datetime):


class S3Index:
def __init__(self: S3IndexType, objects: List[str], prefix: str) -> None:
self.objects = objects
def __init__(self: S3IndexType, objects: Dict[str, str], prefix: str) -> None:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The right-hand side (the dict's value type) should be `str | None`, since an object may have no checksum.

Suggested change
def __init__(self: S3IndexType, objects: Dict[str, str], prefix: str) -> None:
def __init__(self: S3IndexType, objects: Dict[str, str | None], prefix: str) -> None:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there are a couple of other places where this should be done as well.

self.objects = objects # s3 key to checksum mapping
self.prefix = prefix.rstrip("/")
self.html_name = PREFIXES_WITH_HTML[self.prefix]
# should dynamically grab subdirectories like whl/test/cu101
Expand All @@ -146,7 +145,7 @@ def nightly_packages_to_show(self: S3IndexType) -> Set[str]:
# also includes versions without GPU specifier (i.e. cu102) for easier
# sorting, sorts in reverse to put the most recent versions first
all_sorted_packages = sorted(
{self.normalize_package_version(obj) for obj in self.objects},
{self.normalize_package_version(s3_key) for s3_key in self.objects.keys()},

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flake8-simplify (link) suggests just iterating over the dict instead of explicitly calling .keys(). We could talk about the likely very itty bitty performance bump, but lol.

Is there a reason you prefer .keys()?

key=lambda name_ver: parse(name_ver.split('-', 1)[-1]),
reverse=True,
)
Expand All @@ -166,10 +165,12 @@ def nightly_packages_to_show(self: S3IndexType) -> Set[str]:
to_hide.add(obj)
else:
packages[package_name] += 1
return set(self.objects).difference({
obj for obj in self.objects
if self.normalize_package_version(obj) in to_hide
})
nightly_packages = {}
for obj, checksum in self.objects.items():
normalized_package_version = self.normalize_package_version(obj)
if not normalized_package_version in to_hide:
nightly_packages[normalized_package_version] = checksum
return nightly_packages
matteius marked this conversation as resolved.
Show resolved Hide resolved

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, the return type needs updating: the function now returns a dict rather than a `Set[str]`.


def is_obj_at_root(self, obj:str) -> bool:
return path.dirname(obj) == self.prefix
Expand All @@ -190,15 +191,15 @@ def gen_file_list(
else self.objects
)
subdir = self._resolve_subdir(subdir) + '/'
for obj in objects:
for obj, checksum in objects.items():
if package_name is not None:
if self.obj_to_package_name(obj) != package_name:
continue
if self.is_obj_at_root(obj) or obj.startswith(subdir):
yield obj
yield obj, checksum

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type of this function needs updating, since it now yields `(obj, checksum)` tuples instead of bare keys.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is fixed now.


def get_package_names(self, subdir: Optional[str] = None) -> List[str]:
return sorted(set(self.obj_to_package_name(obj) for obj in self.gen_file_list(subdir)))
return sorted(set(self.obj_to_package_name(obj) for obj, _ in self.gen_file_list(subdir)))

def normalize_package_version(self: S3IndexType, obj: str) -> str:
# removes the GPU specifier from the package name as well as
Expand Down Expand Up @@ -226,7 +227,7 @@ def to_legacy_html(
out: List[str] = []
subdir = self._resolve_subdir(subdir)
is_root = subdir == self.prefix
for obj in self.gen_file_list(subdir):
for obj, _ in self.gen_file_list(subdir):
matteius marked this conversation as resolved.
Show resolved Hide resolved
matteius marked this conversation as resolved.
Show resolved Hide resolved
# Strip our prefix
sanitized_obj = obj.replace(subdir, "", 1)
if sanitized_obj.startswith('/'):
Expand Down Expand Up @@ -254,8 +255,11 @@ def to_simple_package_html(
out.append('<html>')
out.append(' <body>')
out.append(' <h1>Links for {}</h1>'.format(package_name.lower().replace("_","-")))
for obj in sorted(self.gen_file_list(subdir, package_name)):
out.append(f' <a href="/{obj}">{path.basename(obj).replace("%2B","+")}</a><br/>')
for obj, checksum in sorted(self.gen_file_list(subdir, package_name)):
if checksum:
out.append(f' <a href="/{obj}#sha256={checksum}">{path.basename(obj).replace("%2B","+")}</a><br/>')
else:
out.append(f' <a href="/{obj}">{path.basename(obj).replace("%2B","+")}</a><br/>')
matteius marked this conversation as resolved.
Show resolved Hide resolved
# Adding html footer
out.append(' </body>')
out.append('</html>')
Expand Down Expand Up @@ -316,7 +320,6 @@ def upload_pep503_htmls(self) -> None:
Body=self.to_simple_package_html(subdir=subdir, package_name=pkg_name)
)


def save_legacy_html(self) -> None:
for subdir in self.subdirs:
print(f"INFO Saving {subdir}/{self.html_name}")
Expand Down Expand Up @@ -348,10 +351,13 @@ def from_S3(cls: Type[S3IndexType], prefix: str) -> S3IndexType:
for pattern in ACCEPTED_SUBDIR_PATTERNS
]) and obj.key.endswith(ACCEPTED_FILE_EXTENSIONS)
if is_acceptable:
response = obj.meta.client.head_object(Bucket=BUCKET.name, Key=obj.key, ChecksumMode="ENABLED")
sha256 = response.get("ChecksumSHA256")
matteius marked this conversation as resolved.
Show resolved Hide resolved
matteius marked this conversation as resolved.
Show resolved Hide resolved
sanitized_key = obj.key.replace("+", "%2B")
objects.append(sanitized_key)
objects.append((sanitized_key, sha256))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This still has `objects` as a list, and I don't see it being converted to a dict anywhere.

I think you can just change the initialization on line 343 to a dict, and change this line to `objects[sanitized_key] = sha256`.

return cls(objects, prefix)


def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser("Manage S3 HTML indices for PyTorch")
parser.add_argument(
Expand All @@ -363,6 +369,7 @@ def create_parser() -> argparse.ArgumentParser:
parser.add_argument("--generate-pep503", action="store_true")
return parser


def main():
parser = create_parser()
args = parser.parse_args()
Expand All @@ -387,5 +394,6 @@ def main():
if args.generate_pep503:
idx.upload_pep503_htmls()


if __name__ == "__main__":
main()