From 1056a19b98309e48e80e0180c025e5d597b859d6 Mon Sep 17 00:00:00 2001
From: Sergey Fedoseev
Date: Mon, 7 Sep 2020 21:19:44 +0500
Subject: [PATCH] Multi-thread hashing of large S3 files

---
 api/python/quilt3/data_transfer.py     | 199 +++++++++++++++++++------
 api/python/tests/test_data_transfer.py |  86 +++++++----
 api/python/tests/utils.py              |  51 +++++++
 3 files changed, 263 insertions(+), 73 deletions(-)

diff --git a/api/python/quilt3/data_transfer.py b/api/python/quilt3/data_transfer.py
index 68d96957b45..9cbd932ea40 100644
--- a/api/python/quilt3/data_transfer.py
+++ b/api/python/quilt3/data_transfer.py
@@ -1,7 +1,11 @@
 import itertools
+import logging
 import math
 import os
+import queue
 import stat
+import threading
+import types
 from collections import defaultdict, deque
 from codecs import iterdecode
 from concurrent.futures import ThreadPoolExecutor
@@ -33,6 +37,9 @@
 MAX_FIX_HASH_RETRIES = 3
 
 
+logger = logging.getLogger(__name__)
+
+
 class S3Api(Enum):
     GET_OBJECT = "GET_OBJECT"
     HEAD_OBJECT = "HEAD_OBJECT"
@@ -859,6 +866,88 @@ def calculate_sha256(src_list: List[PhysicalKey], sizes: List[int]):
     return _calculate_sha256_internal(src_list, sizes, [None] * len(src_list))
 
 
+def _calculate_hash_get_s3_chunks(ctx, src, size):
+    params = dict(Bucket=src.bucket, Key=src.path)
+    if src.version_id is not None:
+        params.update(VersionId=src.version_id)
+    part_size = s3_transfer_config.multipart_chunksize
+    is_multi_part = (
+        size >= s3_transfer_config.multipart_threshold
+        and size > part_size
+    )
+    part_numbers = (
+        range(math.ceil(size / part_size))
+        if is_multi_part else
+        (None,)
+    )
+    s3_client = ctx.find_correct_client(S3Api.GET_OBJECT, src.bucket, params)
+
+    def read_to_queue(part_number, put_to_queue, stopped_event):
+        try:
+            logger.debug('%r part %s: download enqueued', src, part_number)
+            # This semaphore is released in iter_queue() when the part is fully
+            # downloaded and all chunks are retrieved from the queue or if download
+            # fails.
+            ctx.pending_parts_semaphore.acquire()
+            logger.debug('%r part %s: download started', src, part_number)
+            if part_number is not None:
+                start = part_number * part_size
+                end = min(start + part_size, size) - 1
+                part_params = dict(params, Range=f'bytes={start}-{end}')
+            else:
+                part_params = params
+
+            body = s3_client.get_object(**part_params)['Body']
+            for chunk in read_file_chunks(body):
+                put_to_queue(chunk)
+                if stopped_event.is_set():
+                    logger.debug('%r part %s: stopped', src, part_number)
+                    break
+
+            logger.debug('%r part %s: downloaded', src, part_number)
+        finally:
+            put_to_queue(None)
+
+    def iter_queue(part_number):
+        q = queue.Queue()
+        stopped_event = threading.Event()
+        f = ctx.executor.submit(read_to_queue, part_number, q.put_nowait, stopped_event)
+        try:
+            yield
+            yield from iter(q.get, None)
+            f.result()
+            logger.debug('%r part %s: processed', src, part_number)
+        except GeneratorExit:
+            if f.cancel():
+                logger.debug('%r part %s: cancelled', src, part_number)
+            else:
+                stopped_event.set()
+        finally:
+            if not f.cancelled():
+                ctx.pending_parts_semaphore.release()
+                logger.debug('%r part %s: semaphore released', src, part_number)
+
+    generators = deque()
+    for gen in map(iter_queue, part_numbers):
+        # Step into generator, so it will receive GeneratorExit when it's closed
+        # or garbage collected.
+        next(gen)
+        generators.append(gen)
+
+    return itertools.chain.from_iterable(
+        itertools.starmap(generators.popleft, itertools.repeat((), len(part_numbers))))
+
+
+def with_lock(f):
+    lock = threading.Lock()
+
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):
+        with lock:
+            return f(*args, **kwargs)
+    return wrapper
+
+
 @retry(stop=stop_after_attempt(MAX_FIX_HASH_RETRIES),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_result(lambda results: any(r is None or isinstance(r, Exception) for r in results)),
@@ -870,53 +959,77 @@ def _calculate_sha256_internal(src_list, sizes, results):
         for size, result in zip(sizes, results)
         if result is None or isinstance(result, Exception)
     )
-    lock = Lock()
+    # This controls how many parts can be stored in memory.
+    # This includes the ones that are being downloaded or hashed.
+    # The number was chosen empirically.
+    s3_max_pending_parts = s3_transfer_config.max_request_concurrency * 4
+    stopped = False
 
-    with tqdm(desc="Hashing", total=total_size, unit='B', unit_scale=True, disable=DISABLE_TQDM) as progress:
-        def _process_url(src, size):
-            hash_obj = hashlib.sha256()
-            if src.is_local():
-                with open(src.path, 'rb') as fd:
-                    for chunk in read_file_chunks(fd):
-                        hash_obj.update(chunk)
-                        with lock:
-                            progress.update(len(chunk))
-
-                    current_file_size = fd.tell()
-                    if current_file_size != size:
-                        warnings.warn(
-                            f"Expected the package entry at {src!r} to be {size} B in size, but "
-                            f"found an object which is {current_file_size} B instead. This "
-                            f"indicates that the content of the file changed in between when you "
-                            f"included this entry in the package (via set or set_dir) and now. "
-                            f"This should be avoided if possible."
-                        )
+    def get_file_chunks(src, size):
+        with open(src.path, 'rb') as file:
+            yield from read_file_chunks(file)
+
+            current_file_size = file.tell()
+            if current_file_size != size:
+                warnings.warn(
+                    f"Expected the package entry at {src!r} to be {size} B in size, but "
+                    f"found an object which is {current_file_size} B instead. This "
+                    f"indicates that the content of the file changed in between when you "
+                    f"included this entry in the package (via set or set_dir) and now. "
+                    f"This should be avoided if possible."
+                )
 
-            else:
-                params = dict(Bucket=src.bucket, Key=src.path)
-                if src.version_id is not None:
-                    params.update(VersionId=src.version_id)
-                try:
-                    s3_client = S3ClientProvider().find_correct_client(S3Api.GET_OBJECT, src.bucket, params)
-
-                    resp = s3_client.get_object(**params)
-                    body = resp['Body']
-                    for chunk in read_file_chunks(body):
-                        hash_obj.update(chunk)
-                        with lock:
-                            progress.update(len(chunk))
-                except (ConnectionError, HTTPClientError, ReadTimeoutError) as e:
-                    return e
-            return hash_obj.hexdigest()
+    def _process_url(src, size):
+        hash_obj = hashlib.sha256()
 
-        with ThreadPoolExecutor() as executor:
-            future_to_idx = {
-                executor.submit(_process_url, src, size): i
-                for i, (src, size, result) in enumerate(zip(src_list, sizes, results))
-                if result is None or isinstance(result, Exception)
-            }
+        generator, exceptions_to_retry = (
+            (get_file_chunks(src, size), ())
+            if src.is_local() else
+            (
+                _calculate_hash_get_s3_chunks(s3_context, src, size),
+                (ConnectionError, HTTPClientError, ReadTimeoutError)
+            )
+        )
+        try:
+            for chunk in generator:
+                hash_obj.update(chunk)
+                progress_update(len(chunk))
+                if stopped:
+                    return
+        except exceptions_to_retry as e:
+            return e
+        else:
+            return hash_obj.hexdigest()
+        finally:
+            # We want this generator to be closed immediately,
+            # so that it finishes (or cancels) its own download tasks.
+            del generator
+
+    with tqdm(desc="Hashing", total=total_size, unit='B', unit_scale=True, disable=DISABLE_TQDM) as progress, \
+         ThreadPoolExecutor() as executor, \
+         ThreadPoolExecutor(
+             s3_transfer_config.max_request_concurrency,
+             thread_name_prefix='s3-executor',
+         ) as s3_executor:
+        s3_context = types.SimpleNamespace(
+            find_correct_client=with_lock(S3ClientProvider().find_correct_client),
+            pending_parts_semaphore=threading.BoundedSemaphore(s3_max_pending_parts),
+            executor=s3_executor,
+        )
+        progress_update = with_lock(progress.update)
+        future_to_idx = {
+            executor.submit(_process_url, src, size): i
+            for i, (src, size, result) in enumerate(zip(src_list, sizes, results))
+            if result is None or isinstance(result, Exception)
+        }
+        try:
             for future in concurrent.futures.as_completed(future_to_idx):
-                results[future_to_idx[future]] = future.result()
+                results[future_to_idx.pop(future)] = future.result()
+        finally:
+            stopped = True
+            while future_to_idx:
+                future, idx = future_to_idx.popitem()
+                future.cancel()
 
     return results
diff --git a/api/python/tests/test_data_transfer.py b/api/python/tests/test_data_transfer.py
index c7808146a05..c0056f49d3c 100644
--- a/api/python/tests/test_data_transfer.py
+++ b/api/python/tests/test_data_transfer.py
@@ -1,10 +1,10 @@
 """ Testing for data_transfer.py """
 
 # Python imports
+import hashlib
 import io
 import os
 import pathlib
-import threading
 import time
 
 from contextlib import redirect_stderr
@@ -524,37 +524,13 @@ class S3DownloadTest(QuiltTestCase):
     filename = 'some-file-name'
     dst = PhysicalKey(None, filename, None)
 
-    def _test_download(self, *, threshold, chunksize, parts=None, devnull=False):
-        num_parts = 1 if parts is None else len(parts)
-        barrier = threading.Barrier(num_parts, timeout=2)
-
-        def side_effect(*args, **kwargs):
-            barrier.wait()  # This ensures that we have concurrent calls to get_object().
-            return {
-                'VersionId': 'v1',
-                'Body': io.BytesIO(self.data if parts is None else parts[kwargs['Range']]),
-            }
-
+    def _test_download(self, *, threshold, chunksize, parts=data, devnull=False):
         dst = PhysicalKey(None, os.devnull, None) if devnull else self.dst
 
-        with mock.patch('quilt3.data_transfer.s3_transfer_config.max_request_concurrency', num_parts), \
-             mock.patch('quilt3.data_transfer.s3_transfer_config.multipart_threshold', threshold), \
-             mock.patch('quilt3.data_transfer.s3_transfer_config.multipart_chunksize', chunksize), \
-             mock.patch.object(self.s3_client, 'get_object', side_effect=side_effect) as get_object_mock:
-            data_transfer.copy_file_list([(self.src, dst, self.size)])
-
-        expected_params = {
-            'Bucket': self.bucket,
-            'Key': self.key,
-        }
-
-        if parts is None:
-            get_object_mock.assert_called_once_with(**expected_params)
-        else:
-            get_object_mock.assert_has_calls([
-                mock.call(**expected_params, Range=r)
-                for r in parts
-            ], any_order=True)
-            assert len(get_object_mock.call_args_list) == num_parts
+        with self.s3_test_multi_thread_download(
+            self.bucket, self.key, parts, threshold=threshold, chunksize=chunksize
+        ):
+            data_transfer.copy_file_list([(self.src, dst, self.size)])
 
         if not devnull:
             with open(self.filename, 'rb') as f:
@@ -584,3 +560,53 @@ def test_threshold_eq_chunk_eq_size(self):
 
     def test_threshold_eq_chunk_gt_size(self):
         self._test_download(threshold=self.size, chunksize=self.size + 1)
+
+
+class S3HashingTest(QuiltTestCase):
+    data = b'0123456789abcdef'
+    size = len(data)
+    hasher = hashlib.sha256
+
+    bucket = 'test-bucket'
+    key = 'test-key'
+    src = PhysicalKey(bucket, key, None)
+
+    def _hashing_subtest(self, *, threshold, chunksize, data=data):
+        with self.s3_test_multi_thread_download(
+            self.bucket, self.key, data, threshold=threshold, chunksize=chunksize
+        ):
+            assert data_transfer.calculate_sha256([self.src], [self.size]) == [self.hasher(self.data).hexdigest()]
+
+    def test_single_request(self):
+        params = (
+            (self.size + 1, 5),
+            (self.size, self.size),
+            (self.size, self.size + 1),
+            (5, self.size),
+        )
+        for threshold, chunksize in params:
+            with self.subTest(threshold=threshold, chunksize=chunksize):
+                self._hashing_subtest(threshold=threshold, chunksize=chunksize)
+
+    def test_multi_request(self):
+        params = (
+            (
+                self.size, 5, {
+                    'bytes=0-4': self.data[:5],
+                    'bytes=5-9': self.data[5:10],
+                    'bytes=10-14': self.data[10:15],
+                    'bytes=15-15': self.data[15:],
+                }
+            ),
+            (
+                5, self.size - 1, {
+                    'bytes=0-14': self.data[:15],
+                    'bytes=15-15': self.data[15:],
+                }
+            ),
+        )
+        for threshold, chunksize, data in params:
+            for concurrency in (len(data), 1):
+                with mock.patch('quilt3.data_transfer.s3_transfer_config.max_request_concurrency', concurrency):
+                    with self.subTest(threshold=threshold, chunksize=chunksize, data=data, concurrency=concurrency):
+                        self._hashing_subtest(threshold=threshold, chunksize=chunksize, data=data)
diff --git a/api/python/tests/utils.py b/api/python/tests/utils.py
index 7fffe4d16ae..dedb30ac466 100644
--- a/api/python/tests/utils.py
+++ b/api/python/tests/utils.py
@@ -1,12 +1,16 @@
 """ Unittest setup """
 
+import contextlib
+import io
 import pathlib
+import time
 
 from unittest import mock, TestCase
 
 import boto3
 from botocore import UNSIGNED
 from botocore.client import Config
+from botocore.response import StreamingBody
 from botocore.stub import Stubber
 import responses
 
@@ -59,3 +63,50 @@ def tearDown(self):
         self.s3_stubber.deactivate()
         self.s3_client_patcher.stop()
         self.requests_mock.stop()
+
+    def s3_streaming_body(self, data):
+        return StreamingBody(io.BytesIO(data), len(data))
+
+    @contextlib.contextmanager
+    def s3_test_multi_thread_download(self, bucket, key, data, *, threshold, chunksize):
+        """
+        Helper for testing multi-thread download of a single file.
+
+        data is either a bytes object if a single-request download is expected,
+        or a mapping like this:
+        {
+            'bytes=0-4': b'part1',
+            'bytes=5-9': b'part2',
+            ...
+        }
+        """
+        is_single_request = isinstance(data, bytes)
+        num_parts = 1 if is_single_request else len(data)
+        expected_params = {
+            'Bucket': bucket,
+            'Key': key,
+        }
+
+        def side_effect(*args, **kwargs):
+            body = self.s3_streaming_body(data if is_single_request else data[kwargs['Range']])
+            if not is_single_request:
+                # This ensures that we have concurrent calls to get_object().
+                time.sleep(0.1 * (1 - list(data).index(kwargs['Range']) / len(data)))
+            return {
+                'VersionId': 'v1',
+                'Body': body,
+            }
+
+        with mock.patch('quilt3.data_transfer.s3_transfer_config.multipart_threshold', threshold), \
+             mock.patch('quilt3.data_transfer.s3_transfer_config.multipart_chunksize', chunksize), \
+             mock.patch.object(self.s3_client, 'get_object', side_effect=side_effect) as get_object_mock:
+            yield
+
+        if is_single_request:
+            get_object_mock.assert_called_once_with(**expected_params)
+        else:
+            assert get_object_mock.call_count == num_parts
+            get_object_mock.assert_has_calls([
+                mock.call(**expected_params, Range=r)
+                for r in data
+            ], any_order=True)
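
Note (illustrative, not part of the patch): the entry point exercised by the new
S3HashingTest is quilt3.data_transfer.calculate_sha256(), which takes parallel
lists of PhysicalKey objects and sizes and returns one SHA-256 hex digest per
entry. A minimal sketch of a call, with a hypothetical bucket, key, and sizes:

    from quilt3.data_transfer import calculate_sha256
    from quilt3.util import PhysicalKey  # import path assumed; the tests above use PhysicalKey directly

    remote = PhysicalKey('example-bucket', 'path/to/large-object', None)  # hypothetical S3 object
    local = PhysicalKey(None, '/tmp/example.bin', None)                   # local files are hashed too
    digests = calculate_sha256([remote, local], [remote_size, local_size])  # sizes supplied by the caller

With this change, S3 objects whose size reaches s3_transfer_config.multipart_threshold
(and exceeds multipart_chunksize) are fetched as byte-range parts on the dedicated
's3-executor' thread pool, bounded by the pending-parts semaphore, while each
_process_url() worker consumes the streamed chunks and updates its hash.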