Calculate checksum from local file if upload optimization succeeds #3968

Merged · 10 commits · Jun 18, 2024
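
In short: when a file being pushed already exists at the destination, quilt3 can skip the upload, and with this change it also records the package entry's SHA-256 checksum computed from the local file in that case, using the same chunked "checksum of checksums" scheme that S3 reports as ChecksumSHA256 for multipart uploads. The sketch below is a minimal, standalone illustration of that scheme, not the code under review; the fixed 8 MiB part size is an assumption for illustration, whereas the diff derives the part size from get_checksum_chunksize(size).

import base64
import hashlib

PART_SIZE = 8 * 1024 * 1024  # illustrative only; quilt3 uses get_checksum_chunksize(size)

def chunked_sha256(path: str, part_size: int = PART_SIZE) -> str:
    # SHA-256 each part, then SHA-256 the concatenation of the part digests and
    # base64-encode the result, mirroring S3's multipart ChecksumSHA256 format.
    part_digests = []
    with open(path, "rb") as fd:
        while True:
            part = fd.read(part_size)
            if not part:
                break
            part_digests.append(hashlib.sha256(part).digest())
    top_digest = hashlib.sha256(b"".join(part_digests)).digest()
    return base64.b64encode(top_digest).decode()

Note that even a single-part file gets its one part digest hashed a second time, which is presumably what _simple_s3_to_quilt_checksum accounts for in the diff when comparing against S3's plain whole-object checksum.
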
144 changes: 87 additions & 57 deletions api/python/quilt3/data_transfer.py
@@ -268,6 +268,10 @@
     return chunksize
 
 
+def is_mpu(file_size: int) -> bool:
+    return file_size >= CHECKSUM_MULTIPART_THRESHOLD
+
+
 _EMPTY_STRING_SHA256 = hashlib.sha256(b'').digest()

@@ -303,7 +307,7 @@
 def _upload_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_key: str):
     s3_client = ctx.s3_client_provider.standard_client
 
-    if size < CHECKSUM_MULTIPART_THRESHOLD:
+    if not is_mpu(size):
         with ReadFileChunk.from_filename(src_path, 0, size, [ctx.progress]) as fd:
             resp = s3_client.put_object(
                 Body=fd,
@@ -460,7 +464,7 @@
 
     s3_client = ctx.s3_client_provider.standard_client
 
-    if size < CHECKSUM_MULTIPART_THRESHOLD:
+    if not is_mpu(size):
         params: Dict[str, Any] = dict(
             CopySource=src_params,
             Bucket=dest_bucket,
@@ -530,43 +534,62 @@
             ctx.run(upload_part, i, start, end)
 
 
-def _upload_or_copy_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_path: str):
-    # Optimization: check if the remote file already exists and has the right ETag,
-    # and skip the upload.
-    if size >= UPLOAD_ETAG_OPTIMIZATION_THRESHOLD:
-        try:
-            params = dict(Bucket=dest_bucket, Key=dest_path)
-            s3_client = ctx.s3_client_provider.find_correct_client(S3Api.HEAD_OBJECT, dest_bucket, params)
-            resp = s3_client.head_object(**params, ChecksumMode='ENABLED')
-        except ClientError:
-            # Destination doesn't exist, so fall through to the normal upload.
-            pass
-        except S3NoValidClientError:
-            # S3ClientProvider can't currently distinguish between a user that has PUT but not LIST permissions and a
-            # user that has no permissions. If we can't find a valid client, proceed to the upload stage anyway.
-            pass
-        else:
-            # Check the ETag.
-            dest_size = resp['ContentLength']
-            dest_etag = resp['ETag']
-            dest_version_id = resp.get('VersionId')
-            if size == dest_size and resp.get('ServerSideEncryption') != 'aws:kms':
-                src_etag = _calculate_etag(src_path)
-                if src_etag == dest_etag:
-                    # Nothing more to do. We should not attempt to copy the object because
-                    # that would cause the "copy object to itself" error.
-                    # TODO: Check SHA256 before checking ETag?
-                    s3_checksum = resp.get('ChecksumSHA256')
-                    if s3_checksum is None:
-                        checksum = None
-                    elif '-' in s3_checksum:
-                        checksum, _ = s3_checksum.split('-', 1)
-                    else:
-                        checksum = _simple_s3_to_quilt_checksum(s3_checksum)
-                    ctx.progress(size)
-                    ctx.done(PhysicalKey(dest_bucket, dest_path, dest_version_id), checksum)
-                    return  # Optimization succeeded.
+def _calculate_local_checksum(path: str, size: int):
+    chunksize = get_checksum_chunksize(size)
+
+    part_hashes = []
+    for start in range(0, size, chunksize):
+        end = min(start + chunksize, size)
+        part_hashes.append(_calculate_local_part_checksum(path, start, end - start))
+
+    return _make_checksum_from_parts(part_hashes)
+
+
+def _reuse_remote_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_path: str):
+    # Optimization: check if the remote file already exists and has the right ETag,
+    # and skip the upload.
+    if size < UPLOAD_ETAG_OPTIMIZATION_THRESHOLD:
+        return None
+    try:
+        params = dict(Bucket=dest_bucket, Key=dest_path)
+        s3_client = ctx.s3_client_provider.find_correct_client(S3Api.HEAD_OBJECT, dest_bucket, params)
+        resp = s3_client.head_object(**params, ChecksumMode="ENABLED")
+    except ClientError:
+        # Destination doesn't exist, so fall through to the normal upload.
+        pass
+    except S3NoValidClientError:
+        # S3ClientProvider can't currently distinguish between a user that has PUT but not LIST permissions and a
+        # user that has no permissions. If we can't find a valid client, proceed to the upload stage anyway.
+        pass
+    else:
+        dest_size = resp["ContentLength"]
+        if dest_size != size:
+            return None
+        # TODO: we could check hashes of parts, to finish faster
+        s3_checksum = resp.get("ChecksumSHA256")
+        if s3_checksum is not None:
+            if "-" in s3_checksum:
+                checksum, num_parts_str = s3_checksum.split("-", 1)
+                num_parts = int(num_parts_str)
+            else:
+                checksum = _simple_s3_to_quilt_checksum(s3_checksum)
+                num_parts = None
+            expected_num_parts = math.ceil(size / get_checksum_chunksize(size)) if is_mpu(size) else None
+            if num_parts == expected_num_parts and checksum == _calculate_local_checksum(src_path, size):
+                return resp.get("VersionId"), checksum
+        elif resp.get("ServerSideEncryption") != "aws:kms" and resp["ETag"] == _calculate_etag(src_path):
+            return resp.get("VersionId"), _calculate_local_checksum(src_path, size)
+
+    return None
+
+
+def _upload_or_reuse_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_path: str):
+    result = _reuse_remote_file(ctx, size, src_path, dest_bucket, dest_path)
+    if result is not None:
+        dest_version_id, checksum = result
+        ctx.progress(size)
+        ctx.done(PhysicalKey(dest_bucket, dest_path, dest_version_id), checksum)
+        return  # Optimization succeeded.
     # If the optimization didn't happen, do the normal upload.
     _upload_file(ctx, size, src_path, dest_bucket, dest_path)

@@ -648,7 +671,7 @@
         else:
             if dest.version_id:
                 raise ValueError("Cannot set VersionId on destination")
-            _upload_or_copy_file(ctx, size, src.path, dest.bucket, dest.path)
+            _upload_or_reuse_file(ctx, size, src.path, dest.bucket, dest.path)
     else:
         if dest.is_local():
             _download_file(ctx, size, src.bucket, src.path, src.version_id, dest.path)
@@ -701,7 +724,7 @@
     """
     size = pathlib.Path(file_path).stat().st_size
     with open(file_path, 'rb') as fd:
-        if size < CHECKSUM_MULTIPART_THRESHOLD:
+        if not is_mpu(size):
             contents = fd.read()
             etag = hashlib.md5(contents).hexdigest()
         else:
Expand Down Expand Up @@ -970,6 +993,28 @@
return wrapper


def _calculate_local_part_checksum(src: str, offset: int, length: int, callback=None) -> bytes:
hash_obj = hashlib.sha256()
bytes_remaining = length
with open(src, "rb") as fd:
fd.seek(offset)
while bytes_remaining > 0:
chunk = fd.read(min(s3_transfer_config.io_chunksize, bytes_remaining))
if not chunk:
# Should not happen, but let's not get stuck in an infinite loop.
raise QuiltException("Unexpected end of file")

Check warning on line 1005 in api/python/quilt3/data_transfer.py

View check run for this annotation

Codecov / codecov/patch/informational

api/python/quilt3/data_transfer.py#L1005

Added line #L1005 was not covered by tests
hash_obj.update(chunk)
if callback is not None:
callback(len(chunk))
bytes_remaining -= len(chunk)

return hash_obj.digest()


def _make_checksum_from_parts(parts: List[bytes]) -> str:
return binascii.b2a_base64(hashlib.sha256(b"".join(parts)).digest(), newline=False).decode()


@retry(stop=stop_after_attempt(MAX_FIX_HASH_RETRIES),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_result(lambda results: any(r is None or isinstance(r, Exception) for r in results)),
@@ -990,21 +1035,10 @@
     progress_update = with_lock(progress.update)
 
     def _process_url_part(src: PhysicalKey, offset: int, length: int):
-        hash_obj = hashlib.sha256()
-
         if src.is_local():
-            bytes_remaining = length
-            with open(src.path, 'rb') as fd:
-                fd.seek(offset)
-                while bytes_remaining > 0:
-                    chunk = fd.read(min(s3_transfer_config.io_chunksize, bytes_remaining))
-                    if not chunk:
-                        # Should not happen, but let's not get stuck in an infinite loop.
-                        raise QuiltException("Unexpected end of file")
-                    hash_obj.update(chunk)
-                    progress_update(len(chunk))
-                    bytes_remaining -= len(chunk)
+            return _calculate_local_part_checksum(src.path, offset, length, progress_update)
         else:
+            hash_obj = hashlib.sha256()
             end = offset + length - 1
             params = dict(
                 Bucket=src.bucket,
@@ -1026,7 +1060,7 @@
             except (ConnectionError, HTTPClientError, ReadTimeoutError) as ex:
                 return ex
 
-        return hash_obj.digest()
+            return hash_obj.digest()
 
     futures: List[Tuple[int, List[Future]]] = []

@@ -1046,11 +1080,7 @@
         for idx, future_list in futures:
             future_results = [future.result() for future in future_list]
             exceptions = [ex for ex in future_results if isinstance(ex, Exception)]
-            if exceptions:
-                results[idx] = exceptions[0]
-            else:
-                hashes_hash = hashlib.sha256(b''.join(future_results)).digest()
-                results[idx] = binascii.b2a_base64(hashes_hash, newline=False).decode()
+            results[idx] = exceptions[0] if exceptions else _make_checksum_from_parts(future_results)
     finally:
         stopped = True
         for _, future_list in futures:
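
To recap the reuse logic added to data_transfer.py above: _reuse_remote_file issues head_object(..., ChecksumMode="ENABLED") on the destination, bails out if the sizes differ, compares S3's ChecksumSHA256 (reported as "<base64 digest>-<part count>" for multipart uploads) against a checksum computed locally with the same part layout, and only falls back to the older MD5/ETag comparison when no SHA-256 checksum is present and the object is not KMS-encrypted. The following is a condensed, self-contained sketch of just the multipart branch; the fixed part size and the helper name are illustrative assumptions, and the single-part and ETag fallbacks from the diff are omitted.

import math

PART_SIZE = 8 * 1024 * 1024  # stand-in for get_checksum_chunksize(size); illustrative only

def can_reuse_multipart(head_resp: dict, local_size: int, local_checksum: str) -> bool:
    # head_resp: response of head_object(..., ChecksumMode="ENABLED") for the destination key.
    # local_checksum: chunked SHA-256 of the local file (see the sketch near the top of this page).
    if head_resp["ContentLength"] != local_size:
        return False  # different size, definitely a different object
    s3_checksum = head_resp.get("ChecksumSHA256")
    if not s3_checksum or "-" not in s3_checksum:
        return False  # no checksum or single-part object; the diff handles these cases separately
    checksum, num_parts = s3_checksum.split("-", 1)
    # The part layout must match what we would produce locally, otherwise the
    # "checksum of checksums" digests are not comparable.
    if int(num_parts) != math.ceil(local_size / PART_SIZE):
        return False
    return checksum == local_checksum

When this kind of check succeeds, _upload_or_reuse_file records the existing object's VersionId and checksum via ctx.done() instead of re-uploading; otherwise it falls through to _upload_file.
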
4 changes: 2 additions & 2 deletions api/python/quilt3/packages.py
@@ -1540,9 +1540,9 @@ def check_hash_conficts(latest_hash):
             new_entry.hash = dict(type=SHA256_CHUNKED_HASH_NAME, value=checksum)
             pkg._set(logical_key, new_entry)
 
-        # Needed if the files already exist in S3, but were uploaded without ChecksumAlgorithm='SHA256'.
+        # Some entries may miss hash values (e.g. because of selector_fn), so we need
+        # to fix them before calculating the top hash.
         pkg._fix_sha256()
 
         top_hash = pkg._calculate_top_hash(pkg._meta, pkg.walk())
 
         if dedupe and top_hash == latest_hash:
6 changes: 4 additions & 2 deletions api/python/tests/integration/test_packages.py
@@ -1914,10 +1914,11 @@ def test_push_selector_fn_false(self):
         selector_fn = mock.MagicMock(return_value=False)
         push_manifest_mock = self.patch_s3_registry('push_manifest')
         self.patch_s3_registry('shorten_top_hash', return_value='7a67ff4')
-        with patch('quilt3.packages.calculate_checksum', return_value=[('SHA256', "a" * 64)]):
+        with patch('quilt3.packages.calculate_checksum', return_value=["a" * 64]) as calculate_checksum_mock:
             pkg.push(pkg_name, registry=f's3://{dst_bucket}', selector_fn=selector_fn, force=True)
 
         selector_fn.assert_called_once_with(lk, pkg[lk])
+        calculate_checksum_mock.assert_called_once_with([PhysicalKey(src_bucket, src_key, src_version)], [0])
         push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY)
         assert Package.load(
             BytesIO(push_manifest_mock.call_args[0][2])
@@ -1960,10 +1961,11 @@ def test_push_selector_fn_true(self):
         )
         push_manifest_mock = self.patch_s3_registry('push_manifest')
         self.patch_s3_registry('shorten_top_hash', return_value='7a67ff4')
-        with patch('quilt3.packages.calculate_checksum', return_value=["a" * 64]):
+        with patch('quilt3.packages.calculate_checksum', return_value=[]) as calculate_checksum_mock:
             pkg.push(pkg_name, registry=f's3://{dst_bucket}', selector_fn=selector_fn, force=True)
 
         selector_fn.assert_called_once_with(lk, pkg[lk])
+        calculate_checksum_mock.assert_called_once_with([], [])
         push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY)
         assert Package.load(
             BytesIO(push_manifest_mock.call_args[0][2])