Multi-thread hashing of large S3 files
sir-sigurd committed Sep 7, 2020
1 parent e20a922 commit 34c38d6
Showing 2 changed files with 133 additions and 27 deletions.
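
The gist of the change, as a minimal standalone sketch: fetch fixed-size byte ranges of the S3 object concurrently, then feed the parts to a single SHA-256 in order. This is illustrative only; the bucket, key, and part size are placeholders, boto3 credentials are assumed to be configured, and the commit itself drives a thread pool from an asyncio event loop rather than using ThreadPoolExecutor directly.

import hashlib
import math
from concurrent.futures import ThreadPoolExecutor

import boto3


def sha256_s3_object(bucket, key, part_size=8 * 1024 * 1024):
    # Hash an S3 object by downloading its byte ranges in parallel.
    s3 = boto3.client('s3')
    size = s3.head_object(Bucket=bucket, Key=key)['ContentLength']
    if size == 0:
        return hashlib.sha256().hexdigest()

    def get_part(i):
        start = i * part_size
        end = min(start + part_size, size) - 1
        resp = s3.get_object(Bucket=bucket, Key=key, Range=f'bytes={start}-{end}')
        return resp['Body'].read()

    hash_obj = hashlib.sha256()
    with ThreadPoolExecutor() as executor:
        # map() yields results in submission order, so parts are hashed in
        # sequence even though they download concurrently.
        for chunk in executor.map(get_part, range(math.ceil(size / part_size))):
            hash_obj.update(chunk)
    return hash_obj.hexdigest()
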
96 changes: 69 additions & 27 deletions api/python/quilt3/data_transfer.py
@@ -1,3 +1,4 @@
import asyncio
import math
import os
import stat
@@ -19,6 +20,7 @@
from botocore.exceptions import ClientError, ConnectionError, HTTPClientError, ReadTimeoutError
import boto3
from boto3.s3.transfer import TransferConfig
from quilt3 import util
from s3transfer.utils import ChunksizeAdjuster, OSUtils, signal_transferring, signal_not_transferring

import jsonlines
@@ -856,6 +858,46 @@ def calculate_sha256(src_list: List[PhysicalKey], sizes: List[int]):
return _calculate_sha256_internal(src_list, sizes, [None] * len(src_list))


async def _calculate_hash_s3_get_parts(src, size):
loop = asyncio._get_running_loop()

params = dict(Bucket=src.bucket, Key=src.path)
if src.version_id is not None:
params.update(VersionId=src.version_id)
part_size = s3_transfer_config.multipart_chunksize
is_multi_part = (
size >= s3_transfer_config.multipart_threshold
and size > part_size
)
part_numbers = (
range(math.ceil(size / part_size))
if is_multi_part else
(None,)
)

s3_client = await loop.run_in_executor(
None,
S3ClientProvider().find_correct_client, S3Api.GET_OBJECT, src.bucket, params,
)

async def get_part(part_number):
if part_number is not None:
start = part_number * part_size
end = min(start + part_size, size) - 1
part_params = dict(params, Range=f'bytes={start}-{end}')
else:
part_params = params
body = (await loop.run_in_executor(None, lambda: s3_client.get_object(**part_params)))['Body']
return await loop.run_in_executor(None, body.read)

futures = [
asyncio.ensure_future(get_part(i))
for i in part_numbers
]
for future in futures:
yield await future
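
For orientation, the part layout above follows boto3's multipart transfer settings: objects at or above multipart_threshold (and larger than one chunk) are split into multipart_chunksize ranges, while smaller objects are fetched whole (a single None part). A worked example of the range math, assuming an 8 MiB chunk size rather than the configured value:

# Illustrative range math only; the real part_size comes from
# s3_transfer_config.multipart_chunksize.
import math

size = 20 * 1024 * 1024        # 20 MiB object
part_size = 8 * 1024 * 1024    # 8 MiB parts
part_numbers = range(math.ceil(size / part_size))  # 0, 1, 2
ranges = [
    (n * part_size, min((n + 1) * part_size, size) - 1)
    for n in part_numbers
]
# ranges == [(0, 8388607), (8388608, 16777215), (16777216, 20971519)]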


@retry(stop=stop_after_attempt(MAX_FIX_HASH_RETRIES),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_result(lambda results: any(r is None or isinstance(r, Exception) for r in results)),
@@ -867,20 +909,20 @@ def _calculate_sha256_internal(src_list, sizes, results):
for size, result in zip(sizes, results)
if result is None or isinstance(result, Exception)
)
lock = Lock()

with tqdm(desc="Hashing", total=total_size, unit='B', unit_scale=True, disable=DISABLE_TQDM) as progress:
def _process_url(src, size):
hash_obj = hashlib.sha256()
if src.is_local():
async def main():
loop = asyncio._get_running_loop()

with tqdm(desc="Hashing", total=total_size, unit='B', unit_scale=True, disable=DISABLE_TQDM) as progress:
def _process_local(src, size):
hash_obj = hashlib.sha256()
with open(src.path, 'rb') as fd:
while True:
chunk = fd.read(64 * 1024)
if not chunk:
break
hash_obj.update(chunk)
with lock:
progress.update(len(chunk))
loop.call_soon_threadsafe(progress.update, len(chunk))

current_file_size = fd.tell()
if current_file_size != size:
@@ -891,33 +933,33 @@ def _process_url(src, size):
f"included this entry in the package (via set or set_dir) and now. "
f"This should be avoided if possible."
)
return hash_obj.hexdigest()

else:
params = dict(Bucket=src.bucket, Key=src.path)
if src.version_id is not None:
params.update(VersionId=src.version_id)
async def _process_s3(src, size):
hash_obj = hashlib.sha256()
try:
s3_client = S3ClientProvider().find_correct_client(S3Api.GET_OBJECT, src.bucket, params)

resp = s3_client.get_object(**params)
body = resp['Body']
for chunk in body:
async for chunk in _calculate_hash_s3_get_parts(src, size):
hash_obj.update(chunk)
with lock:
progress.update(len(chunk))
progress.update(len(chunk))
except (ConnectionError, HTTPClientError, ReadTimeoutError) as e:
return e
return hash_obj.hexdigest()

with ThreadPoolExecutor() as executor:
future_to_idx = {
executor.submit(_process_url, src, size): i
for i, (src, size, result) in enumerate(zip(src_list, sizes, results))
if result is None or isinstance(result, Exception)
}
for future in concurrent.futures.as_completed(future_to_idx):
results[future_to_idx[future]] = future.result()
return hash_obj.hexdigest()

async def _process_url(src, size):
if src.is_local():
return await loop.run_in_executor(None, _process_local, src, size)
else:
return await _process_s3(src, size)

fs = [
asyncio.ensure_future(_process_url(src, size))
for src, size, result in zip(src_list, sizes, results)
]
for i, f in enumerate(fs):
results[i] = await f

util.asyncio_run(main())
return results
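
Note the shared Lock around progress.update is gone: local files are still hashed on worker threads, and those threads now schedule the tqdm update back onto the event loop with call_soon_threadsafe, while the S3 path updates the bar directly from its coroutine on the loop thread. A minimal, self-contained illustration of the thread-safe update pattern (the plain counter stands in for the progress bar; all names are illustrative):

import asyncio


def demo():
    loop = asyncio.new_event_loop()
    total = 0

    def bump(n):
        # Runs on the event-loop thread only, so no lock is needed.
        nonlocal total
        total += n

    def worker():
        # Runs in an executor thread; hand the update to the loop thread.
        loop.call_soon_threadsafe(bump, 1024)

    async def main():
        await asyncio.gather(*(loop.run_in_executor(None, worker) for _ in range(4)))

    loop.run_until_complete(main())
    loop.close()
    return total  # 4096


print(demo())
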


64 changes: 64 additions & 0 deletions api/python/quilt3/util.py
@@ -534,3 +534,67 @@ def catalog_package_url(catalog_url, bucket, package_name, package_timestamp="la
if tree:
package_url = package_url + f"/tree/{package_timestamp}"
return package_url


def asyncio_run(main, *, debug=None):
"""Execute the coroutine and return the result.
This function runs the passed coroutine, taking care of
managing the asyncio event loop and finalizing asynchronous
generators.
This function cannot be called when another asyncio event loop is
running in the same thread.
If debug is True, the event loop will be run in debug mode.
This function always creates a new event loop and closes it at the end.
It should be used as a main entry point for asyncio programs, and should
ideally only be called once.
Example:
async def main():
await asyncio.sleep(1)
print('hello')
asyncio.run(main())
"""
from asyncio import coroutines
from asyncio import events
from asyncio import tasks

def _cancel_all_tasks(loop):
to_cancel = tasks.Task.all_tasks(loop)
if not to_cancel:
return

for task in to_cancel:
task.cancel()

loop.run_until_complete(
tasks.gather(*to_cancel, loop=loop, return_exceptions=True))

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'unhandled exception during asyncio.run() shutdown',
'exception': task.exception(),
'task': task,
})

if events._get_running_loop() is not None:
raise RuntimeError(
"asyncio.run() cannot be called from a running event loop")

if not coroutines.iscoroutine(main):
raise ValueError("a coroutine was expected, got {!r}".format(main))

loop = events.new_event_loop()
try:
events.set_event_loop(loop)
if debug is not None:
loop.set_debug(debug)
return loop.run_until_complete(main)
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
events.set_event_loop(None)
loop.close()
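
A minimal usage sketch of the helper (mirroring the docstring example above; the helper matches the semantics of asyncio.run(), which is only available on Python 3.7+):

import asyncio

from quilt3 import util


async def main():
    await asyncio.sleep(1)
    print('hello')


util.asyncio_run(main())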
