Multi-thread hashing of large S3 files
sir-sigurd committed Sep 23, 2020
1 parent 2018aa2 commit 02b6709
Showing 3 changed files with 245 additions and 68 deletions.
176 changes: 138 additions & 38 deletions api/python/quilt3/data_transfer.py
@@ -1,7 +1,11 @@
import itertools
import logging
import math
import os
import queue
import stat
import threading
import types
from collections import defaultdict, deque
from codecs import iterdecode
from concurrent.futures import ThreadPoolExecutor
@@ -33,6 +37,9 @@
MAX_FIX_HASH_RETRIES = 3


logger = logging.getLogger(__name__)


class S3Api(Enum):
GET_OBJECT = "GET_OBJECT"
HEAD_OBJECT = "HEAD_OBJECT"
@@ -859,6 +866,83 @@ def calculate_sha256(src_list: List[PhysicalKey], sizes: List[int]):
return _calculate_sha256_internal(src_list, sizes, [None] * len(src_list))


def _calculate_hash_get_s3_chunks(ctx, src, size):
params = dict(Bucket=src.bucket, Key=src.path)
if src.version_id is not None:
params.update(VersionId=src.version_id)
part_size = s3_transfer_config.multipart_chunksize
is_multi_part = (
size >= s3_transfer_config.multipart_threshold
and size > part_size
)
part_numbers = (
range(math.ceil(size / part_size))
if is_multi_part else
(None,)
)
s3_client = ctx.find_correct_client(S3Api.GET_OBJECT, src.bucket, params)

def read_to_queue(part_number, put_to_queue, stopped_event):
try:
logger.debug('%r part %s: download enqueued', src, part_number)
ctx.pending_parts_semaphore.acquire()
logger.debug('%r part %s: download started', src, part_number)
if part_number is not None:
start = part_number * part_size
end = min(start + part_size, size) - 1
part_params = dict(params, Range=f'bytes={start}-{end}')
else:
part_params = params

body = s3_client.get_object(**part_params)['Body']
for chunk in read_file_chunks(body):
put_to_queue(chunk)
if stopped_event.is_set():
logger.debug('%r part %s: stopped', src, part_number)
break

logger.debug('%r part %s: downloaded', src, part_number)
finally:
put_to_queue(None)

def iter_queue(part_number):
q = queue.Queue()
stopped_event = threading.Event()
f = ctx.executor.submit(read_to_queue, part_number, q.put_nowait, stopped_event)
try:
yield
yield from iter(q.get, None)
f.result()
logger.debug('%r part %s: processed', src, part_number)
except GeneratorExit:
if f.cancel():
logger.debug('%r part %s: cancelled', src, part_number)
else:
stopped_event.set()
finally:
if not f.cancelled():
ctx.pending_parts_semaphore.release()
logger.debug('%r part %s: semaphore released', src, part_number)

generators = deque()
for gen in map(iter_queue, part_numbers):
next(gen)
generators.append(gen)

return itertools.chain.from_iterable(
itertools.starmap(generators.popleft, itertools.repeat((), len(part_numbers))))
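
The priming-then-chaining pattern above kicks off every part download as soon as its generator is primed with next() (bounded by the semaphore), yet the chained iterator still yields chunks strictly in part order. A minimal sketch of the same idea, with a print standing in for the real S3 request (hypothetical names, not part of this commit):

# Illustrative sketch only.
import itertools
from collections import deque

def make_part(n):
    print(f'part {n}: submitted')   # side effect happens at priming time
    yield                           # priming point, like the bare `yield` above
    yield from (f'{n}-{i}' for i in range(2))

parts = deque()
for gen in map(make_part, range(3)):
    next(gen)                       # advance to the first yield: every part is started
    parts.append(gen)

chunks = itertools.chain.from_iterable(
    itertools.starmap(parts.popleft, itertools.repeat((), 3)))
print(list(chunks))                 # ['0-0', '0-1', '1-0', '1-1', '2-0', '2-1']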


def with_lock(f):
lock = threading.Lock()

@functools.wraps(f)
def wrapper(*args, **kwargs):
with lock:
return f(*args, **kwargs)
return wrapper
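
with_lock wraps a callable so that concurrent callers take a shared lock. A quick sketch with a hypothetical counter showing the serialization (not part of this commit); in the diff it is applied as a plain wrapper (e.g. with_lock(progress.update)), and the decorator form shown here is equivalent:

# Illustrative sketch only.
from concurrent.futures import ThreadPoolExecutor

counter = {'n': 0}

@with_lock
def bump():
    counter['n'] += 1               # only one thread at a time executes this

with ThreadPoolExecutor(max_workers=8) as pool:
    for _ in range(1000):
        pool.submit(bump)
print(counter['n'])                 # 1000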


@retry(stop=stop_after_attempt(MAX_FIX_HASH_RETRIES),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_result(lambda results: any(r is None or isinstance(r, Exception) for r in results)),
@@ -870,53 +954,69 @@ def _calculate_sha256_internal(src_list, sizes, results):
for size, result in zip(sizes, results)
if result is None or isinstance(result, Exception)
)
lock = Lock()
    # This controls how many parts can be stored in memory.
# This includes the ones that are being downloaded or hashed.
# The number was chosen empirically.
s3_max_pending_parts = s3_transfer_config.max_request_concurrency * 4
stopped = False

with tqdm(desc="Hashing", total=total_size, unit='B', unit_scale=True, disable=DISABLE_TQDM) as progress:
def _process_url(src, size):
hash_obj = hashlib.sha256()
if src.is_local():
with open(src.path, 'rb') as fd:
for chunk in read_file_chunks(fd):
hash_obj.update(chunk)
with lock:
progress.update(len(chunk))

current_file_size = fd.tell()
if current_file_size != size:
warnings.warn(
f"Expected the package entry at {src!r} to be {size} B in size, but "
f"found an object which is {current_file_size} B instead. This "
f"indicates that the content of the file changed in between when you "
f"included this entry in the package (via set or set_dir) and now. "
f"This should be avoided if possible."
)
def get_file_chunks(src, size):
with open(src.path, 'rb') as file:
yield from read_file_chunks(file)

current_file_size = file.tell()
if current_file_size != size:
warnings.warn(
f"Expected the package entry at {src!r} to be {size} B in size, but "
f"found an object which is {current_file_size} B instead. This "
f"indicates that the content of the file changed in between when you "
f"included this entry in the package (via set or set_dir) and now. "
f"This should be avoided if possible."
)

else:
params = dict(Bucket=src.bucket, Key=src.path)
if src.version_id is not None:
params.update(VersionId=src.version_id)
try:
s3_client = S3ClientProvider().find_correct_client(S3Api.GET_OBJECT, src.bucket, params)

resp = s3_client.get_object(**params)
body = resp['Body']
for chunk in read_file_chunks(body):
hash_obj.update(chunk)
with lock:
progress.update(len(chunk))
except (ConnectionError, HTTPClientError, ReadTimeoutError) as e:
return e
return hash_obj.hexdigest()

with ThreadPoolExecutor() as executor:
def _process_url(src, size):
hash_obj = hashlib.sha256()

generator, exceptions_to_retry = (
(get_file_chunks(src, size), ())
if src.is_local() else
(
_calculate_hash_get_s3_chunks(s3_context, src, size),
(ConnectionError, HTTPClientError, ReadTimeoutError)
)
)
try:
for chunk in generator:
hash_obj.update(chunk)
progress_update(len(chunk))
if stopped:
return
except exceptions_to_retry as e:
return e
return hash_obj.hexdigest()

with tqdm(desc="Hashing", total=total_size, unit='B', unit_scale=True, disable=DISABLE_TQDM) as progress, \
ThreadPoolExecutor() as executor, \
ThreadPoolExecutor(
s3_transfer_config.max_request_concurrency,
thread_name_prefix='s3-executor',
) as s3_executor:
s3_context = types.SimpleNamespace(
find_correct_client=with_lock(S3ClientProvider().find_correct_client),
pending_parts_semaphore=threading.BoundedSemaphore(s3_max_pending_parts),
executor=s3_executor,
)
progress_update = with_lock(progress.update)
try:
future_to_idx = {
executor.submit(_process_url, src, size): i
for i, (src, size, result) in enumerate(zip(src_list, sizes, results))
if result is None or isinstance(result, Exception)
}
for future in concurrent.futures.as_completed(future_to_idx):
results[future_to_idx[future]] = future.result()
finally:
stopped = True

return results
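
For reference, a hedged usage sketch of the public entry point; the import paths and the local file below are assumptions for illustration, not taken from this commit:

# Illustrative sketch only -- import locations assumed.
import os
from quilt3.data_transfer import calculate_sha256
from quilt3.util import PhysicalKey

path = '/tmp/example.bin'                    # hypothetical local file
src = PhysicalKey(None, path, None)          # local entry: no bucket, no version id
hashes = calculate_sha256([src], [os.path.getsize(path)])
print(hashes)   # one hex digest per entry, or an Exception once retries are exhausted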

86 changes: 56 additions & 30 deletions api/python/tests/test_data_transfer.py
@@ -1,10 +1,10 @@
""" Testing for data_transfer.py """

# Python imports
import hashlib
import io
import os
import pathlib
import threading
import time
from contextlib import redirect_stderr

@@ -524,37 +524,13 @@ class S3DownloadTest(QuiltTestCase):
filename = 'some-file-name'
dst = PhysicalKey(None, filename, None)

def _test_download(self, *, threshold, chunksize, parts=None, devnull=False):
num_parts = 1 if parts is None else len(parts)
barrier = threading.Barrier(num_parts, timeout=2)

def side_effect(*args, **kwargs):
barrier.wait() # This ensures that we have concurrent calls to get_object().
return {
'VersionId': 'v1',
'Body': io.BytesIO(self.data if parts is None else parts[kwargs['Range']]),
}

def _test_download(self, *, threshold, chunksize, parts=data, devnull=False):
dst = PhysicalKey(None, os.devnull, None) if devnull else self.dst
with mock.patch('quilt3.data_transfer.s3_transfer_config.max_request_concurrency', num_parts), \
mock.patch('quilt3.data_transfer.s3_transfer_config.multipart_threshold', threshold), \
mock.patch('quilt3.data_transfer.s3_transfer_config.multipart_chunksize', chunksize), \
mock.patch.object(self.s3_client, 'get_object', side_effect=side_effect) as get_object_mock:
data_transfer.copy_file_list([(self.src, dst, self.size)])

expected_params = {
'Bucket': self.bucket,
'Key': self.key,
}

if parts is None:
get_object_mock.assert_called_once_with(**expected_params)
else:
get_object_mock.assert_has_calls([
mock.call(**expected_params, Range=r)
for r in parts
], any_order=True)
assert len(get_object_mock.call_args_list) == num_parts
with self.s3_test_multi_thread_download(
self.bucket, self.key, parts, threshold=threshold, chunksize=chunksize
):
data_transfer.copy_file_list([(self.src, dst, self.size)])

if not devnull:
with open(self.filename, 'rb') as f:
@@ -584,3 +560,53 @@ def test_threshold_eq_chunk_eq_size(self):

def test_threshold_eq_chunk_gt_size(self):
self._test_download(threshold=self.size, chunksize=self.size + 1)


class S3HashingTest(QuiltTestCase):
data = b'0123456789abcdef'
size = len(data)
hasher = hashlib.sha256

bucket = 'test-bucket'
key = 'test-key'
src = PhysicalKey(bucket, key, None)

def _hashing_subtest(self, *, threshold, chunksize, data=data):
with self.s3_test_multi_thread_download(
self.bucket, self.key, data, threshold=threshold, chunksize=chunksize
):
assert data_transfer.calculate_sha256([self.src], [self.size]) == [self.hasher(self.data).hexdigest()]

def test_single_request(self):
params = (
(self.size + 1, 5),
(self.size, self.size),
(self.size, self.size + 1),
(5, self.size),
)
for threshold, chunksize in params:
with self.subTest(threshold=threshold, chunksize=chunksize):
self._hashing_subtest(threshold=threshold, chunksize=chunksize)

def test_multi_request(self):
params = (
(
self.size, 5, {
'bytes=0-4': self.data[:5],
'bytes=5-9': self.data[5:10],
'bytes=10-14': self.data[10:15],
'bytes=15-15': self.data[15:],
}
),
(
5, self.size - 1, {
'bytes=0-14': self.data[:15],
'bytes=15-15': self.data[15:],
}
),
)
for threshold, chunksize, data in params:
for concurrency in (len(data), 1):
with mock.patch('quilt3.data_transfer.s3_transfer_config.max_request_concurrency', concurrency):
with self.subTest(threshold=threshold, chunksize=chunksize, data=data, concurrency=concurrency):
self._hashing_subtest(threshold=threshold, chunksize=chunksize, data=data)
51 changes: 51 additions & 0 deletions api/python/tests/utils.py
@@ -1,12 +1,16 @@
"""
Unittest setup
"""
import contextlib
import io
import pathlib
import time
from unittest import mock, TestCase

import boto3
from botocore import UNSIGNED
from botocore.client import Config
from botocore.response import StreamingBody
from botocore.stub import Stubber
import responses

@@ -59,3 +63,50 @@ def tearDown(self):
self.s3_stubber.deactivate()
self.s3_client_patcher.stop()
self.requests_mock.stop()

def s3_streaming_body(self, data):
return StreamingBody(io.BytesIO(data), len(data))

@contextlib.contextmanager
def s3_test_multi_thread_download(self, bucket, key, data, *, threshold, chunksize):
"""
Helper for testing multi-thread download of a single file.
data is either a bytes object if a single-request download is expected,
or a mapping like this:
{
'bytes=0-4': b'part1',
'bytes=5-9': b'part2',
...
}
"""
is_single_request = isinstance(data, bytes)
num_parts = 1 if is_single_request else len(data)
expected_params = {
'Bucket': bucket,
'Key': key,
}

def side_effect(*args, **kwargs):
body = self.s3_streaming_body(data if is_single_request else data[kwargs['Range']])
if not is_single_request:
# This ensures that we have concurrent calls to get_object().
time.sleep(0.1 * (1 - list(data).index(kwargs['Range']) / len(data)))
return {
'VersionId': 'v1',
'Body': body,
}

with mock.patch('quilt3.data_transfer.s3_transfer_config.multipart_threshold', threshold), \
mock.patch('quilt3.data_transfer.s3_transfer_config.multipart_chunksize', chunksize), \
mock.patch.object(self.s3_client, 'get_object', side_effect=side_effect) as get_object_mock:
yield

if is_single_request:
get_object_mock.assert_called_once_with(**expected_params)
else:
assert get_object_mock.call_count == num_parts
get_object_mock.assert_has_calls([
mock.call(**expected_params, Range=r)
for r in data
], any_order=True)
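
A hypothetical sketch of a test driving this helper end to end (class and helper names reuse those from the diffs above; the exact test is illustrative only):

# Illustrative sketch only.
class ExampleHashingTest(QuiltTestCase):
    def test_two_part_hash(self):
        data = b'0123456789'
        parts = {'bytes=0-4': data[:5], 'bytes=5-9': data[5:]}
        src = PhysicalKey('test-bucket', 'test-key', None)
        with self.s3_test_multi_thread_download(
            'test-bucket', 'test-key', parts, threshold=len(data), chunksize=5
        ):
            result = data_transfer.calculate_sha256([src], [len(data)])
        assert result == [hashlib.sha256(data).hexdigest()]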
