Multi-thread hashing of large S3 files
sir-sigurd committed Sep 23, 2020
1 parent 2018aa2 commit 1056a19
Showing 3 changed files with 263 additions and 73 deletions.
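At a high level, this change hashes a large S3 object by downloading it as ranged parts on several worker threads while a single SHA-256 hasher consumes the resulting chunks in order. Below is a minimal sketch of that idea, not the quilt3 implementation: it buffers whole parts in memory, assumes boto3, and uses placeholder part-size and worker values.

import hashlib
import math
from concurrent.futures import ThreadPoolExecutor

import boto3


def sha256_of_s3_object(bucket, key, size, part_size=8 * 1024 * 1024, workers=10):
    """Hash an S3 object by fetching byte ranges concurrently (illustrative sketch)."""
    s3 = boto3.client('s3')

    def get_part(part_number):
        start = part_number * part_size
        end = min(start + part_size, size) - 1  # Range is inclusive on both ends
        resp = s3.get_object(Bucket=bucket, Key=key, Range=f'bytes={start}-{end}')
        return resp['Body'].read()

    hash_obj = hashlib.sha256()
    with ThreadPoolExecutor(max_workers=workers) as executor:
        # executor.map() yields results in submission order, so the hasher sees
        # the parts in order even though they download concurrently.
        for part in executor.map(get_part, range(math.ceil(size / part_size))):
            hash_obj.update(part)
    return hash_obj.hexdigest()

The implementation in the diff below streams each part through a queue instead of buffering it whole, and bounds the number of in-flight parts with a semaphore so memory use stays limited.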
199 changes: 156 additions & 43 deletions api/python/quilt3/data_transfer.py
@@ -1,7 +1,11 @@
import itertools
import logging
import math
import os
import queue
import stat
import threading
import types
from collections import defaultdict, deque
from codecs import iterdecode
from concurrent.futures import ThreadPoolExecutor
@@ -33,6 +37,9 @@
MAX_FIX_HASH_RETRIES = 3


logger = logging.getLogger(__name__)


class S3Api(Enum):
GET_OBJECT = "GET_OBJECT"
HEAD_OBJECT = "HEAD_OBJECT"
@@ -859,6 +866,88 @@ def calculate_sha256(src_list: List[PhysicalKey], sizes: List[int]):
return _calculate_sha256_internal(src_list, sizes, [None] * len(src_list))


def _calculate_hash_get_s3_chunks(ctx, src, size):
params = dict(Bucket=src.bucket, Key=src.path)
if src.version_id is not None:
params.update(VersionId=src.version_id)
part_size = s3_transfer_config.multipart_chunksize
is_multi_part = (
size >= s3_transfer_config.multipart_threshold
and size > part_size
)
part_numbers = (
range(math.ceil(size / part_size))
if is_multi_part else
(None,)
)
s3_client = ctx.find_correct_client(S3Api.GET_OBJECT, src.bucket, params)

def read_to_queue(part_number, put_to_queue, stopped_event):
try:
logger.debug('%r part %s: download enqueued', src, part_number)
# This semaphore is released in iter_queue() when the part is fully
# downloaded and all chunks are retrieved from the queue or if download
# fails.
ctx.pending_parts_semaphore.acquire()
logger.debug('%r part %s: download started', src, part_number)
if part_number is not None:
start = part_number * part_size
end = min(start + part_size, size) - 1
part_params = dict(params, Range=f'bytes={start}-{end}')
else:
part_params = params

body = s3_client.get_object(**part_params)['Body']
for chunk in read_file_chunks(body):
put_to_queue(chunk)
if stopped_event.is_set():
logger.debug('%r part %s: stopped', src, part_number)
break

logger.debug('%r part %s: downloaded', src, part_number)
finally:
put_to_queue(None)

def iter_queue(part_number):
q = queue.Queue()
stopped_event = threading.Event()
f = ctx.executor.submit(read_to_queue, part_number, q.put_nowait, stopped_event)
try:
yield
yield from iter(q.get, None)
f.result()
logger.debug('%r part %s: processed', src, part_number)
except GeneratorExit:
if f.cancel():
logger.debug('%r part %s: cancelled', src, part_number)
else:
stopped_event.set()
finally:
if not f.cancelled():
ctx.pending_parts_semaphore.release()
logger.debug('%r part %s: semaphore released', src, part_number)

generators = deque()
for gen in map(iter_queue, part_numbers):
# Step into the generator so that it receives GeneratorExit when it's closed
# or garbage collected.
next(gen)
generators.append(gen)

return itertools.chain.from_iterable(
itertools.starmap(generators.popleft, itertools.repeat((), len(part_numbers))))


def with_lock(f):
lock = threading.Lock()

@functools.wraps(f)
def wrapper(*args, **kwargs):
with lock:
return f(*args, **kwargs)
return wrapper


@retry(stop=stop_after_attempt(MAX_FIX_HASH_RETRIES),
wait=wait_exponential(multiplier=1, min=1, max=10),
retry=retry_if_result(lambda results: any(r is None or isinstance(r, Exception) for r in results)),
@@ -870,53 +959,77 @@ def _calculate_sha256_internal(src_list, sizes, results):
for size, result in zip(sizes, results)
if result is None or isinstance(result, Exception)
)
lock = Lock()
# This controls how many parts can be stored in memory.
# This includes parts that are being downloaded or hashed.
# The number was chosen empirically.
s3_max_pending_parts = s3_transfer_config.max_request_concurrency * 4
stopped = False

with tqdm(desc="Hashing", total=total_size, unit='B', unit_scale=True, disable=DISABLE_TQDM) as progress:
def _process_url(src, size):
hash_obj = hashlib.sha256()
if src.is_local():
with open(src.path, 'rb') as fd:
for chunk in read_file_chunks(fd):
hash_obj.update(chunk)
with lock:
progress.update(len(chunk))

current_file_size = fd.tell()
if current_file_size != size:
warnings.warn(
f"Expected the package entry at {src!r} to be {size} B in size, but "
f"found an object which is {current_file_size} B instead. This "
f"indicates that the content of the file changed in between when you "
f"included this entry in the package (via set or set_dir) and now. "
f"This should be avoided if possible."
)
def get_file_chunks(src, size):
with open(src.path, 'rb') as file:
yield from read_file_chunks(file)

current_file_size = file.tell()
if current_file_size != size:
warnings.warn(
f"Expected the package entry at {src!r} to be {size} B in size, but "
f"found an object which is {current_file_size} B instead. This "
f"indicates that the content of the file changed in between when you "
f"included this entry in the package (via set or set_dir) and now. "
f"This should be avoided if possible."
)

else:
params = dict(Bucket=src.bucket, Key=src.path)
if src.version_id is not None:
params.update(VersionId=src.version_id)
try:
s3_client = S3ClientProvider().find_correct_client(S3Api.GET_OBJECT, src.bucket, params)

resp = s3_client.get_object(**params)
body = resp['Body']
for chunk in read_file_chunks(body):
hash_obj.update(chunk)
with lock:
progress.update(len(chunk))
except (ConnectionError, HTTPClientError, ReadTimeoutError) as e:
return e
return hash_obj.hexdigest()
def _process_url(src, size):
hash_obj = hashlib.sha256()

with ThreadPoolExecutor() as executor:
future_to_idx = {
executor.submit(_process_url, src, size): i
for i, (src, size, result) in enumerate(zip(src_list, sizes, results))
if result is None or isinstance(result, Exception)
}
generator, exceptions_to_retry = (
(get_file_chunks(src, size), ())
if src.is_local() else
(
_calculate_hash_get_s3_chunks(s3_context, src, size),
(ConnectionError, HTTPClientError, ReadTimeoutError)
)
)
try:
for chunk in generator:
hash_obj.update(chunk)
progress_update(len(chunk))
if stopped:
return
except exceptions_to_retry as e:
return e
else:
return hash_obj.hexdigest()
finally:
# Drop the generator immediately so that it runs its own cleanup
# (stopping or cancelling its pending part downloads).
del generator

with tqdm(desc="Hashing", total=total_size, unit='B', unit_scale=True, disable=DISABLE_TQDM) as progress, \
ThreadPoolExecutor() as executor, \
ThreadPoolExecutor(
s3_transfer_config.max_request_concurrency,
thread_name_prefix='s3-executor',
) as s3_executor:
s3_context = types.SimpleNamespace(
find_correct_client=with_lock(S3ClientProvider().find_correct_client),
pending_parts_semaphore=threading.BoundedSemaphore(s3_max_pending_parts),
executor=s3_executor,
)
progress_update = with_lock(progress.update)
future_to_idx = {
executor.submit(_process_url, src, size): i
for i, (src, size, result) in enumerate(zip(src_list, sizes, results))
if result is None or isinstance(result, Exception)
}
try:
for future in concurrent.futures.as_completed(future_to_idx):
results[future_to_idx[future]] = future.result()
results[future_to_idx.pop(future)] = future.result()
finally:
stopped = True
while future_to_idx:
future, idx = future_to_idx.popitem()
future.cancel()

return results

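The diff above also adds a small with_lock decorator so that shared callables (the tqdm progress update and the cached S3 client lookup) can be invoked safely from multiple worker threads. A self-contained sketch of that pattern:

import functools
import threading

from tqdm import tqdm


def with_lock(f):
    # One lock per wrapped callable; every call holds it for the duration.
    lock = threading.Lock()

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        with lock:
            return f(*args, **kwargs)
    return wrapper


with tqdm(total=1024, unit='B', unit_scale=True) as progress:
    progress_update = with_lock(progress.update)  # now safe to call from any worker thread
    progress_update(512)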
86 changes: 56 additions & 30 deletions api/python/tests/test_data_transfer.py
@@ -1,10 +1,10 @@
""" Testing for data_transfer.py """

# Python imports
import hashlib
import io
import os
import pathlib
import threading
import time
from contextlib import redirect_stderr

@@ -524,37 +524,13 @@ class S3DownloadTest(QuiltTestCase):
filename = 'some-file-name'
dst = PhysicalKey(None, filename, None)

def _test_download(self, *, threshold, chunksize, parts=None, devnull=False):
num_parts = 1 if parts is None else len(parts)
barrier = threading.Barrier(num_parts, timeout=2)

def side_effect(*args, **kwargs):
barrier.wait() # This ensures that we have concurrent calls to get_object().
return {
'VersionId': 'v1',
'Body': io.BytesIO(self.data if parts is None else parts[kwargs['Range']]),
}

def _test_download(self, *, threshold, chunksize, parts=data, devnull=False):
dst = PhysicalKey(None, os.devnull, None) if devnull else self.dst
with mock.patch('quilt3.data_transfer.s3_transfer_config.max_request_concurrency', num_parts), \
mock.patch('quilt3.data_transfer.s3_transfer_config.multipart_threshold', threshold), \
mock.patch('quilt3.data_transfer.s3_transfer_config.multipart_chunksize', chunksize), \
mock.patch.object(self.s3_client, 'get_object', side_effect=side_effect) as get_object_mock:
data_transfer.copy_file_list([(self.src, dst, self.size)])

expected_params = {
'Bucket': self.bucket,
'Key': self.key,
}

if parts is None:
get_object_mock.assert_called_once_with(**expected_params)
else:
get_object_mock.assert_has_calls([
mock.call(**expected_params, Range=r)
for r in parts
], any_order=True)
assert len(get_object_mock.call_args_list) == num_parts
with self.s3_test_multi_thread_download(
self.bucket, self.key, parts, threshold=threshold, chunksize=chunksize
):
data_transfer.copy_file_list([(self.src, dst, self.size)])

if not devnull:
with open(self.filename, 'rb') as f:
@@ -584,3 +560,53 @@ def test_threshold_eq_chunk_eq_size(self):

def test_threshold_eq_chunk_gt_size(self):
self._test_download(threshold=self.size, chunksize=self.size + 1)


class S3HashingTest(QuiltTestCase):
data = b'0123456789abcdef'
size = len(data)
hasher = hashlib.sha256

bucket = 'test-bucket'
key = 'test-key'
src = PhysicalKey(bucket, key, None)

def _hashing_subtest(self, *, threshold, chunksize, data=data):
with self.s3_test_multi_thread_download(
self.bucket, self.key, data, threshold=threshold, chunksize=chunksize
):
assert data_transfer.calculate_sha256([self.src], [self.size]) == [self.hasher(self.data).hexdigest()]

def test_single_request(self):
params = (
(self.size + 1, 5),
(self.size, self.size),
(self.size, self.size + 1),
(5, self.size),
)
for threshold, chunksize in params:
with self.subTest(threshold=threshold, chunksize=chunksize):
self._hashing_subtest(threshold=threshold, chunksize=chunksize)

def test_multi_request(self):
params = (
(
self.size, 5, {
'bytes=0-4': self.data[:5],
'bytes=5-9': self.data[5:10],
'bytes=10-14': self.data[10:15],
'bytes=15-15': self.data[15:],
}
),
(
5, self.size - 1, {
'bytes=0-14': self.data[:15],
'bytes=15-15': self.data[15:],
}
),
)
for threshold, chunksize, data in params:
for concurrency in (len(data), 1):
with mock.patch('quilt3.data_transfer.s3_transfer_config.max_request_concurrency', concurrency):
with self.subTest(threshold=threshold, chunksize=chunksize, data=data, concurrency=concurrency):
self._hashing_subtest(threshold=threshold, chunksize=chunksize, data=data)
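For orientation, a hypothetical call to the entry point these tests exercise; the bucket, key, and size values are placeholders.

from quilt3.data_transfer import PhysicalKey, calculate_sha256

# 'test-bucket', 'test-key', and the 16-byte size are placeholder values.
src = PhysicalKey('test-bucket', 'test-key', None)  # bucket, key, version_id
# Returns a list with one SHA-256 hex digest per input entry.
print(calculate_sha256([src], [16]))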