Commit

Get rid of asyncio
sir-sigurd committed Sep 18, 2020
1 parent 550ed9a commit 09d6cd5
Showing 4 changed files with 105 additions and 186 deletions.
219 changes: 105 additions & 114 deletions api/python/quilt3/data_transfer.py
@@ -1,12 +1,16 @@
import asyncio
import itertools
import logging
import math
import os
import queue
import stat
import threading
import types
from collections import defaultdict, deque
from codecs import iterdecode
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
import concurrent
import functools
import hashlib
import pathlib
@@ -20,7 +24,6 @@
from botocore.exceptions import ClientError, ConnectionError, HTTPClientError, ReadTimeoutError
import boto3
from boto3.s3.transfer import TransferConfig
from quilt3 import util
from s3transfer.utils import ChunksizeAdjuster, OSUtils, signal_transferring, signal_not_transferring

import jsonlines
@@ -217,6 +220,11 @@ def check_head_object_works_for_client(s3_client, params):

s3_transfer_config = TransferConfig()


def read_file_chunks(file, chunksize=s3_transfer_config.io_chunksize):
return itertools.takewhile(bool, map(file.read, itertools.repeat(chunksize)))


# When uploading files at least this size, compare the ETags first and skip the upload if they're equal;
# copy the remote file onto itself if the metadata changes.
UPLOAD_ETAG_OPTIMIZATION_THRESHOLD = 1024
@@ -614,10 +622,7 @@ def _calculate_etag(file_path):
chunksize = adjuster.adjust_chunksize(s3_transfer_config.multipart_chunksize, size)

hashes = []
while True:
contents = fd.read(chunksize)
if not contents:
break
for contents in read_file_chunks(fd, chunksize):
hashes.append(hashlib.md5(contents).digest())
etag = '%s-%d' % (hashlib.md5(b''.join(hashes)).hexdigest(), len(hashes))
return '"%s"' % etag
@@ -861,9 +866,7 @@ def calculate_sha256(src_list: List[PhysicalKey], sizes: List[int]):
return _calculate_sha256_internal(src_list, sizes, [None] * len(src_list))


async def _calculate_hash_s3_get_parts(s3_client_provider, s3_executor, s3_pending_parts_semaphore, src, size):
loop = asyncio._get_running_loop()

def _calculate_hash_get_s3_chunks(ctx, src, size):
params = dict(Bucket=src.bucket, Key=src.path)
if src.version_id is not None:
params.update(VersionId=src.version_id)
@@ -878,57 +881,50 @@ async def _calculate_hash_s3_get_parts(s3_client_provider, s3_executor, s3_pendi
(None,)
)

s3_client = s3_client_provider.find_correct_client(S3Api.GET_OBJECT, src.bucket, params)
def read_to_queue(part_number, part_chunks_q):
try:
ctx.pending_parts_semaphore.acquire()
logger.debug('%r part %s: download started', src, part_number)
if part_number is not None:
start = part_number * part_size
end = min(start + part_size, size) - 1
part_params = dict(params, Range=f'bytes={start}-{end}')
else:
part_params = params

def read_to_queue(part_number, q):
logger.debug('%r part %s: download started', src, part_number)
if part_number is not None:
start = part_number * part_size
end = min(start + part_size, size) - 1
part_params = dict(params, Range=f'bytes={start}-{end}')
else:
part_params = params

body = s3_client.get_object(**part_params)['Body']
while True:
chunk = body.read(s3_transfer_config.io_chunksize)
if not chunk:
break
loop.call_soon_threadsafe(q.put_nowait, chunk)
logger.debug('%r part %s: downloaded', src, part_number)

async def get_part(part_number, q):
await s3_pending_parts_semaphore.acquire()
body = s3_client.get_object(**part_params)['Body']
for chunk in read_file_chunks(body):
part_chunks_q.put_nowait(chunk)

logger.debug('%r part %s: downloaded', src, part_number)
finally:
part_chunks_q.put_nowait(None)

s3_client = ctx.find_correct_client(S3Api.GET_OBJECT, src.bucket, params)
part_futures = deque()
for part_number in part_numbers:
part_chunks_q = queue.Queue()
future = ctx.executor.submit(read_to_queue, part_number, part_chunks_q)
part_futures.append((part_number, future, part_chunks_q))

while part_futures:
part_number, future, part_chunks_q = part_futures.popleft()
try:
logger.debug('%r part %s: download enqueued', src, part_number)
await loop.run_in_executor(s3_executor, read_to_queue, part_number, q)
except Exception:
s3_pending_parts_semaphore.release()
raise
yield from iter(part_chunks_q.get, None)
future.result()
logger.debug('%r part %s: processed', src, part_number)
finally:
ctx.pending_parts_semaphore.release()

futures = deque()
for i in part_numbers:
q = asyncio.Queue()
futures.append((i, asyncio.ensure_future(get_part(i, q)), q))

# This gives a chance to other 'hash file' tasks schedule 'download part' tasks.
# It's needed because we want to process as many files as possible at once, this allows
# to process them sequentially, so we don't have to store the pending parts.
await asyncio.sleep(0)

while futures:
i, part_downloader, q = futures.popleft()
while True:
get_new_chunk = asyncio.ensure_future(q.get())
await asyncio.wait({part_downloader, get_new_chunk}, return_when=asyncio.FIRST_COMPLETED)
if get_new_chunk.done():
yield get_new_chunk.result()
else:
get_new_chunk.cancel()
part_downloader.result()
break
s3_pending_parts_semaphore.release()
logger.debug('%r part %s: processed', src, i)

def with_lock(f):
lock = threading.Lock()

@functools.wraps(f)
def wrapper(*args, **kwds):
with lock:
return f(*args, **kwds)
return wrapper


@retry(stop=stop_after_attempt(MAX_FIX_HASH_RETRIES),
@@ -942,69 +938,64 @@ def _calculate_sha256_internal(src_list, sizes, results):
for size, result in zip(sizes, results)
if result is None or isinstance(result, Exception)
)
s3_client_provider = S3ClientProvider()
# This controls how many parts can be stored in the memory.
# This includes the ones that are being downloaded or hashed.
# The number was chosen empirically.
s3_max_pending_parts = s3_transfer_config.max_request_concurrency * 4

async def main():
loop = asyncio._get_running_loop()

s3_pending_parts_semaphore = asyncio.BoundedSemaphore(s3_max_pending_parts)
with tqdm(desc="Hashing", total=total_size, unit='B', unit_scale=True, disable=DISABLE_TQDM) as progress, \
ThreadPoolExecutor(
s3_transfer_config.max_request_concurrency,
thread_name_prefix='s3-executor',
) as s3_executor:
def _process_local(src, size):
hash_obj = hashlib.sha256()
with open(src.path, 'rb') as fd:
while True:
chunk = fd.read(64 * 1024)
if not chunk:
break
hash_obj.update(chunk)
loop.call_soon_threadsafe(progress.update, len(chunk))

current_file_size = fd.tell()
if current_file_size != size:
warnings.warn(
f"Expected the package entry at {src!r} to be {size} B in size, but "
f"found an object which is {current_file_size} B instead. This "
f"indicates that the content of the file changed in between when you "
f"included this entry in the package (via set or set_dir) and now. "
f"This should be avoided if possible."
)
return hash_obj.hexdigest()

async def _process_s3(src, size):
hash_obj = hashlib.sha256()
try:
async for chunk in _calculate_hash_s3_get_parts(
s3_client_provider, s3_executor, s3_pending_parts_semaphore, src, size
):
await loop.run_in_executor(None, hash_obj.update, chunk)
progress.update(len(chunk))
except (ConnectionError, HTTPClientError, ReadTimeoutError) as e:
return e

return hash_obj.hexdigest()

def _process_url(src, size):
if src.is_local():
return loop.run_in_executor(None, _process_local, src, size)
else:
return _process_s3(src, size)
def get_file_chunks(src, size):
with open(src.path, 'rb') as file:
yield from read_file_chunks(file)

current_file_size = file.tell()
if current_file_size != size:
warnings.warn(
f"Expected the package entry at {src!r} to be {size} B in size, but "
f"found an object which is {current_file_size} B instead. This "
f"indicates that the content of the file changed in between when you "
f"included this entry in the package (via set or set_dir) and now. "
f"This should be avoided if possible."
)

fs = [
(i, asyncio.ensure_future(_process_url(src, size)))
for i, (src, size, result) in enumerate(zip(src_list, sizes, results))
if result is None or isinstance(result, Exception)
]
for i, f in fs:
results[i] = await f
def _process_url(src, size):
hash_obj = hashlib.sha256()

generator, exceptions_to_retry = (
(get_file_chunks(src, size), ())
if src.is_local() else
(
_calculate_hash_get_s3_chunks(s3_context, src, size),
(ConnectionError, HTTPClientError, ReadTimeoutError)
)
)
try:
for chunk in generator:
hash_obj.update(chunk)
progress_update(len(chunk))
except exceptions_to_retry as e:
return e
return hash_obj.hexdigest()

with tqdm(desc="Hashing", total=total_size, unit='B', unit_scale=True, disable=DISABLE_TQDM) as progress, \
ThreadPoolExecutor() as executor, \
ThreadPoolExecutor(
s3_transfer_config.max_request_concurrency,
thread_name_prefix='s3-executor',
) as s3_executor:
s3_context = types.SimpleNamespace(
find_correct_client=with_lock(S3ClientProvider().find_correct_client),
pending_parts_semaphore=threading.BoundedSemaphore(s3_max_pending_parts),
executor=s3_executor,
)
progress_update = with_lock(progress.update)
future_to_idx = {
executor.submit(_process_url, src, size): i
for i, (src, size, result) in enumerate(zip(src_list, sizes, results))
if result is None or isinstance(result, Exception)
}
for future in concurrent.futures.as_completed(future_to_idx):
results[future_to_idx[future]] = future.result()

util.asyncio_run(main())
return results


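The heart of this commit is the switch from an asyncio producer/consumer (asyncio.Queue plus ensure_future driving run_in_executor calls) to plain threads: each S3 part is fetched by a ThreadPoolExecutor worker that pushes chunks onto a queue.Queue and closes it with a None sentinel, while the caller drains the queues in part order. The sketch below is a minimal, self-contained illustration of that pattern using an in-memory blob in place of ranged S3 GETs; names like stream_parts and open_part are hypothetical, and the real diff above additionally bounds in-flight parts with a threading.BoundedSemaphore and serializes shared calls (client lookup, progress updates) with the new with_lock helper.

import hashlib
import io
import itertools
import queue
from concurrent.futures import ThreadPoolExecutor


def read_file_chunks(file, chunksize=64 * 1024):
    # Yield successive read() results until the file is exhausted (read() returns b'').
    return itertools.takewhile(bool, map(file.read, itertools.repeat(chunksize)))


def stream_parts(executor, open_part, part_numbers):
    # Download every part on the executor; each worker pushes its chunks into a
    # dedicated queue.Queue and finishes with a None sentinel.  The consumer then
    # drains the queues in part order, so chunks come out sequentially.
    def read_to_queue(part_number, part_chunks_q):
        try:
            with open_part(part_number) as body:
                for chunk in read_file_chunks(body):
                    part_chunks_q.put_nowait(chunk)
        finally:
            part_chunks_q.put_nowait(None)  # end-of-part sentinel

    part_futures = []
    for part_number in part_numbers:
        part_chunks_q = queue.Queue()
        future = executor.submit(read_to_queue, part_number, part_chunks_q)
        part_futures.append((future, part_chunks_q))

    for future, part_chunks_q in part_futures:
        yield from iter(part_chunks_q.get, None)
        future.result()  # re-raise any exception from the worker


if __name__ == '__main__':
    # Hypothetical stand-in for ranged S3 GETs: split an in-memory blob into parts.
    data = b'hello world' * 100_000
    part_size = 256 * 1024
    part_numbers = range(-(-len(data) // part_size))  # ceil(len / part_size)

    def open_part(part_number):
        start = part_number * part_size
        return io.BytesIO(data[start:start + part_size])

    hash_obj = hashlib.sha256()
    with ThreadPoolExecutor(max_workers=4) as executor:
        for chunk in stream_parts(executor, open_part, part_numbers):
            hash_obj.update(chunk)

    assert hash_obj.hexdigest() == hashlib.sha256(data).hexdigest()
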
70 changes: 0 additions & 70 deletions api/python/quilt3/util.py
@@ -1,5 +1,4 @@
import re
import sys
from collections import OrderedDict
import datetime
import json
@@ -537,72 +536,3 @@ def catalog_package_url(catalog_url, bucket, package_name, package_timestamp="la
if tree:
package_url = package_url + f"/tree/{package_timestamp}"
return package_url


if sys.version_info < (3, 7):
# Copy-pasted from
# https://github.com/python/cpython/blob/457d4e97de0369bc786e363cb53c7ef3276fdfcd/Lib/asyncio/runners.py#L8-L74
def asyncio_run(main, *, debug=None):
"""Execute the coroutine and return the result.
This function runs the passed coroutine, taking care of
managing the asyncio event loop and finalizing asynchronous
generators.
This function cannot be called when another asyncio event loop is
running in the same thread.
If debug is True, the event loop will be run in debug mode.
This function always creates a new event loop and closes it at the end.
It should be used as a main entry point for asyncio programs, and should
ideally only be called once.
Example:
async def main():
await asyncio.sleep(1)
print('hello')
asyncio.run(main())
"""
from asyncio import coroutines
from asyncio import events
from asyncio import tasks

def _cancel_all_tasks(loop):
to_cancel = tasks.Task.all_tasks(loop)
if not to_cancel:
return

for task in to_cancel:
task.cancel()

loop.run_until_complete(
tasks.gather(*to_cancel, loop=loop, return_exceptions=True))

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'unhandled exception during asyncio.run() shutdown',
'exception': task.exception(),
'task': task,
})

if events._get_running_loop() is not None:
raise RuntimeError(
"asyncio.run() cannot be called from a running event loop")

if not coroutines.iscoroutine(main):
raise ValueError("a coroutine was expected, got {!r}".format(main))

loop = events.new_event_loop()
try:
events.set_event_loop(loop)
if debug is not None:
loop.set_debug(debug)
return loop.run_until_complete(main)
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
events.set_event_loop(None)
loop.close()
else:
from asyncio import run as asyncio_run # pylint: disable=unused-import
1 change: 0 additions & 1 deletion api/python/setup.py
@@ -85,7 +85,6 @@ def run(self):
'pytest<5.1.0', # TODO: Fix pytest.ensuretemp in conftest.py
'pytest-cov',
'pytest-env',
'pytest-timeout',
'responses',
'tox',
'detox',
1 change: 0 additions & 1 deletion api/python/tests/test_data_transfer.py
@@ -588,7 +588,6 @@ def test_single_request(self):
with self.subTest(threshold=threshold, chunksize=chunksize):
self._hashing_subtest(threshold=threshold, chunksize=chunksize)

@pytest.mark.timeout(5)
def test_multi_request(self):
params = (
(
