From 09d6cd5acb921d89bc8f7f2c3babfa1acf6be1e9 Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Thu, 17 Sep 2020 20:18:42 +0500 Subject: [PATCH] Get rid of asyncio --- api/python/quilt3/data_transfer.py | 219 ++++++++++++------------- api/python/quilt3/util.py | 70 -------- api/python/setup.py | 1 - api/python/tests/test_data_transfer.py | 1 - 4 files changed, 105 insertions(+), 186 deletions(-) diff --git a/api/python/quilt3/data_transfer.py b/api/python/quilt3/data_transfer.py index f7ebc466fd6..7ca507c31b4 100644 --- a/api/python/quilt3/data_transfer.py +++ b/api/python/quilt3/data_transfer.py @@ -1,12 +1,16 @@ -import asyncio +import itertools import logging import math import os +import queue import stat +import threading +import types from collections import defaultdict, deque from codecs import iterdecode from concurrent.futures import ThreadPoolExecutor from enum import Enum +import concurrent import functools import hashlib import pathlib @@ -20,7 +24,6 @@ from botocore.exceptions import ClientError, ConnectionError, HTTPClientError, ReadTimeoutError import boto3 from boto3.s3.transfer import TransferConfig -from quilt3 import util from s3transfer.utils import ChunksizeAdjuster, OSUtils, signal_transferring, signal_not_transferring import jsonlines @@ -217,6 +220,11 @@ def check_head_object_works_for_client(s3_client, params): s3_transfer_config = TransferConfig() + +def read_file_chunks(file, chunksize=s3_transfer_config.io_chunksize): + return itertools.takewhile(bool, map(file.read, itertools.repeat(chunksize))) + + # When uploading files at least this size, compare the ETags first and skip the upload if they're equal; # copy the remote file onto itself if the metadata changes. UPLOAD_ETAG_OPTIMIZATION_THRESHOLD = 1024 @@ -614,10 +622,7 @@ def _calculate_etag(file_path): chunksize = adjuster.adjust_chunksize(s3_transfer_config.multipart_chunksize, size) hashes = [] - while True: - contents = fd.read(chunksize) - if not contents: - break + for contents in read_file_chunks(fd, chunksize): hashes.append(hashlib.md5(contents).digest()) etag = '%s-%d' % (hashlib.md5(b''.join(hashes)).hexdigest(), len(hashes)) return '"%s"' % etag @@ -861,9 +866,7 @@ def calculate_sha256(src_list: List[PhysicalKey], sizes: List[int]): return _calculate_sha256_internal(src_list, sizes, [None] * len(src_list)) -async def _calculate_hash_s3_get_parts(s3_client_provider, s3_executor, s3_pending_parts_semaphore, src, size): - loop = asyncio._get_running_loop() - +def _calculate_hash_get_s3_chunks(ctx, src, size): params = dict(Bucket=src.bucket, Key=src.path) if src.version_id is not None: params.update(VersionId=src.version_id) @@ -878,57 +881,50 @@ async def _calculate_hash_s3_get_parts(s3_client_provider, s3_executor, s3_pendi (None,) ) - s3_client = s3_client_provider.find_correct_client(S3Api.GET_OBJECT, src.bucket, params) + def read_to_queue(part_number, part_chunks_q): + try: + ctx.pending_parts_semaphore.acquire() + logger.debug('%r part %s: download started', src, part_number) + if part_number is not None: + start = part_number * part_size + end = min(start + part_size, size) - 1 + part_params = dict(params, Range=f'bytes={start}-{end}') + else: + part_params = params - def read_to_queue(part_number, q): - logger.debug('%r part %s: download started', src, part_number) - if part_number is not None: - start = part_number * part_size - end = min(start + part_size, size) - 1 - part_params = dict(params, Range=f'bytes={start}-{end}') - else: - part_params = params - - body = 
s3_client.get_object(**part_params)['Body'] - while True: - chunk = body.read(s3_transfer_config.io_chunksize) - if not chunk: - break - loop.call_soon_threadsafe(q.put_nowait, chunk) - logger.debug('%r part %s: downloaded', src, part_number) - - async def get_part(part_number, q): - await s3_pending_parts_semaphore.acquire() + body = s3_client.get_object(**part_params)['Body'] + for chunk in read_file_chunks(body): + part_chunks_q.put_nowait(chunk) + + logger.debug('%r part %s: downloaded', src, part_number) + finally: + part_chunks_q.put_nowait(None) + + s3_client = ctx.find_correct_client(S3Api.GET_OBJECT, src.bucket, params) + part_futures = deque() + for part_number in part_numbers: + part_chunks_q = queue.Queue() + future = ctx.executor.submit(read_to_queue, part_number, part_chunks_q) + part_futures.append((part_number, future, part_chunks_q)) + + while part_futures: + part_number, future, part_chunks_q = part_futures.popleft() try: - logger.debug('%r part %s: download enqueued', src, part_number) - await loop.run_in_executor(s3_executor, read_to_queue, part_number, q) - except Exception: - s3_pending_parts_semaphore.release() - raise + yield from iter(part_chunks_q.get, None) + future.result() + logger.debug('%r part %s: processed', src, part_number) + finally: + ctx.pending_parts_semaphore.release() - futures = deque() - for i in part_numbers: - q = asyncio.Queue() - futures.append((i, asyncio.ensure_future(get_part(i, q)), q)) - - # This gives a chance to other 'hash file' tasks schedule 'download part' tasks. - # It's needed because we want to process as many files as possible at once, this allows - # to process them sequentially, so we don't have to store the pending parts. - await asyncio.sleep(0) - - while futures: - i, part_downloader, q = futures.popleft() - while True: - get_new_chunk = asyncio.ensure_future(q.get()) - await asyncio.wait({part_downloader, get_new_chunk}, return_when=asyncio.FIRST_COMPLETED) - if get_new_chunk.done(): - yield get_new_chunk.result() - else: - get_new_chunk.cancel() - part_downloader.result() - break - s3_pending_parts_semaphore.release() - logger.debug('%r part %s: processed', src, i) + +def with_lock(f): + lock = threading.Lock() + + @functools.wraps(f) + def wrapper(*args, **kwds): + with lock: + return f(*args, **kwds) + return wrapper @retry(stop=stop_after_attempt(MAX_FIX_HASH_RETRIES), @@ -942,69 +938,64 @@ def _calculate_sha256_internal(src_list, sizes, results): for size, result in zip(sizes, results) if result is None or isinstance(result, Exception) ) - s3_client_provider = S3ClientProvider() # This controls how many parts can be stored in the memory. # This includes the ones that are being downloaded or hashed. + # The number was chosen empirically. 
s3_max_pending_parts = s3_transfer_config.max_request_concurrency * 4 - async def main(): - loop = asyncio._get_running_loop() - - s3_pending_parts_semaphore = asyncio.BoundedSemaphore(s3_max_pending_parts) - with tqdm(desc="Hashing", total=total_size, unit='B', unit_scale=True, disable=DISABLE_TQDM) as progress, \ - ThreadPoolExecutor( - s3_transfer_config.max_request_concurrency, - thread_name_prefix='s3-executor', - ) as s3_executor: - def _process_local(src, size): - hash_obj = hashlib.sha256() - with open(src.path, 'rb') as fd: - while True: - chunk = fd.read(64 * 1024) - if not chunk: - break - hash_obj.update(chunk) - loop.call_soon_threadsafe(progress.update, len(chunk)) - - current_file_size = fd.tell() - if current_file_size != size: - warnings.warn( - f"Expected the package entry at {src!r} to be {size} B in size, but " - f"found an object which is {current_file_size} B instead. This " - f"indicates that the content of the file changed in between when you " - f"included this entry in the package (via set or set_dir) and now. " - f"This should be avoided if possible." - ) - return hash_obj.hexdigest() - - async def _process_s3(src, size): - hash_obj = hashlib.sha256() - try: - async for chunk in _calculate_hash_s3_get_parts( - s3_client_provider, s3_executor, s3_pending_parts_semaphore, src, size - ): - await loop.run_in_executor(None, hash_obj.update, chunk) - progress.update(len(chunk)) - except (ConnectionError, HTTPClientError, ReadTimeoutError) as e: - return e - - return hash_obj.hexdigest() - - def _process_url(src, size): - if src.is_local(): - return loop.run_in_executor(None, _process_local, src, size) - else: - return _process_s3(src, size) + def get_file_chunks(src, size): + with open(src.path, 'rb') as file: + yield from read_file_chunks(file) + + current_file_size = file.tell() + if current_file_size != size: + warnings.warn( + f"Expected the package entry at {src!r} to be {size} B in size, but " + f"found an object which is {current_file_size} B instead. This " + f"indicates that the content of the file changed in between when you " + f"included this entry in the package (via set or set_dir) and now. " + f"This should be avoided if possible." 
+ ) - fs = [ - (i, asyncio.ensure_future(_process_url(src, size))) - for i, (src, size, result) in enumerate(zip(src_list, sizes, results)) - if result is None or isinstance(result, Exception) - ] - for i, f in fs: - results[i] = await f + def _process_url(src, size): + hash_obj = hashlib.sha256() + + generator, exceptions_to_retry = ( + (get_file_chunks(src, size), ()) + if src.is_local() else + ( + _calculate_hash_get_s3_chunks(s3_context, src, size), + (ConnectionError, HTTPClientError, ReadTimeoutError) + ) + ) + try: + for chunk in generator: + hash_obj.update(chunk) + progress_update(len(chunk)) + except exceptions_to_retry as e: + return e + return hash_obj.hexdigest() + + with tqdm(desc="Hashing", total=total_size, unit='B', unit_scale=True, disable=DISABLE_TQDM) as progress, \ + ThreadPoolExecutor() as executor, \ + ThreadPoolExecutor( + s3_transfer_config.max_request_concurrency, + thread_name_prefix='s3-executor', + ) as s3_executor: + s3_context = types.SimpleNamespace( + find_correct_client=with_lock(S3ClientProvider().find_correct_client), + pending_parts_semaphore=threading.BoundedSemaphore(s3_max_pending_parts), + executor=s3_executor, + ) + progress_update = with_lock(progress.update) + future_to_idx = { + executor.submit(_process_url, src, size): i + for i, (src, size, result) in enumerate(zip(src_list, sizes, results)) + if result is None or isinstance(result, Exception) + } + for future in concurrent.futures.as_completed(future_to_idx): + results[future_to_idx[future]] = future.result() - util.asyncio_run(main()) return results diff --git a/api/python/quilt3/util.py b/api/python/quilt3/util.py index 8d9b63e8748..a4e4f226abc 100644 --- a/api/python/quilt3/util.py +++ b/api/python/quilt3/util.py @@ -1,5 +1,4 @@ import re -import sys from collections import OrderedDict import datetime import json @@ -537,72 +536,3 @@ def catalog_package_url(catalog_url, bucket, package_name, package_timestamp="la if tree: package_url = package_url + f"/tree/{package_timestamp}" return package_url - - -if sys.version_info < (3, 7): - # Copy-pasted from - # https://github.com/python/cpython/blob/457d4e97de0369bc786e363cb53c7ef3276fdfcd/Lib/asyncio/runners.py#L8-L74 - def asyncio_run(main, *, debug=None): - """Execute the coroutine and return the result. - This function runs the passed coroutine, taking care of - managing the asyncio event loop and finalizing asynchronous - generators. - This function cannot be called when another asyncio event loop is - running in the same thread. - If debug is True, the event loop will be run in debug mode. - This function always creates a new event loop and closes it at the end. - It should be used as a main entry point for asyncio programs, and should - ideally only be called once. 
- Example: - async def main(): - await asyncio.sleep(1) - print('hello') - asyncio.run(main()) - """ - from asyncio import coroutines - from asyncio import events - from asyncio import tasks - - def _cancel_all_tasks(loop): - to_cancel = tasks.Task.all_tasks(loop) - if not to_cancel: - return - - for task in to_cancel: - task.cancel() - - loop.run_until_complete( - tasks.gather(*to_cancel, loop=loop, return_exceptions=True)) - - for task in to_cancel: - if task.cancelled(): - continue - if task.exception() is not None: - loop.call_exception_handler({ - 'message': 'unhandled exception during asyncio.run() shutdown', - 'exception': task.exception(), - 'task': task, - }) - - if events._get_running_loop() is not None: - raise RuntimeError( - "asyncio.run() cannot be called from a running event loop") - - if not coroutines.iscoroutine(main): - raise ValueError("a coroutine was expected, got {!r}".format(main)) - - loop = events.new_event_loop() - try: - events.set_event_loop(loop) - if debug is not None: - loop.set_debug(debug) - return loop.run_until_complete(main) - finally: - try: - _cancel_all_tasks(loop) - loop.run_until_complete(loop.shutdown_asyncgens()) - finally: - events.set_event_loop(None) - loop.close() -else: - from asyncio import run as asyncio_run # pylint: disable=unused-import diff --git a/api/python/setup.py b/api/python/setup.py index ecbe391fb54..b3667b00bd0 100644 --- a/api/python/setup.py +++ b/api/python/setup.py @@ -85,7 +85,6 @@ def run(self): 'pytest<5.1.0', # TODO: Fix pytest.ensuretemp in conftest.py 'pytest-cov', 'pytest-env', - 'pytest-timeout', 'responses', 'tox', 'detox', diff --git a/api/python/tests/test_data_transfer.py b/api/python/tests/test_data_transfer.py index 0bc83f8902d..c0056f49d3c 100644 --- a/api/python/tests/test_data_transfer.py +++ b/api/python/tests/test_data_transfer.py @@ -588,7 +588,6 @@ def test_single_request(self): with self.subTest(threshold=threshold, chunksize=chunksize): self._hashing_subtest(threshold=threshold, chunksize=chunksize) - @pytest.mark.timeout(5) def test_multi_request(self): params = ( (
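
Note on the pattern (illustration only, not part of the patch): the new
_calculate_hash_get_s3_chunks hands downloaded chunks from S3 worker threads to
the hashing generator through a queue.Queue, using None as an end-of-stream
sentinel consumed via iter(q.get, None), and surfaces producer exceptions via
future.result(). Below is a minimal standalone sketch of that hand-off; the
helper names (produce_chunks, stream_chunks) and the in-memory data source are
made up for illustration and are not quilt3 APIs.

import queue
from concurrent.futures import ThreadPoolExecutor


def produce_chunks(chunks, out_q):
    """Producer: push each chunk, then a None sentinel, even on failure."""
    try:
        for chunk in chunks:
            out_q.put_nowait(chunk)
    finally:
        # Always signal end-of-stream so the consumer never blocks forever.
        out_q.put_nowait(None)


def stream_chunks(out_q, future):
    """Consumer: yield until the sentinel, then re-raise any producer error."""
    yield from iter(out_q.get, None)
    future.result()


if __name__ == '__main__':
    data = [b'part-%d' % i for i in range(5)]  # stand-in for S3 part bodies
    q = queue.Queue()
    with ThreadPoolExecutor(max_workers=1) as executor:
        fut = executor.submit(produce_chunks, data, q)
        for chunk in stream_chunks(q, fut):
            print(chunk)

The None in the finally block mirrors read_to_queue above: the consumer always
unblocks even if the download fails, and future.result() re-raises the
producer's exception on the hashing thread instead of swallowing it.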