Implement semi-atomic push #2689
First changed file (package implementation):
@@ -17,6 +17,7 @@
from collections import deque
from multiprocessing import Pool

import botocore.exceptions
import jsonlines
from tqdm import tqdm
@@ -375,11 +376,10 @@ def object_pairs_hook(items):
class Package:
    """ In-memory representation of a package """

    _origin = None

    def __init__(self):
        self._children = {}
        self._meta = {'version': 'v0'}
        self._origin = None

    @ApiTelemetry("package.__repr__")
    def __repr__(self, max_lines=20):
@@ -604,7 +604,9 @@ def download_manifest(dst):
            stack.callback(os.unlink, local_pkg_manifest)
            download_manifest(local_pkg_manifest)

            return cls._from_path(local_pkg_manifest)
            pkg = cls._from_path(local_pkg_manifest)
            pkg._origin = PackageRevInfo(str(registry.base), name, top_hash)
            return pkg

    @classmethod
    def _from_path(cls, path):
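`PackageRevInfo` is referenced but not defined in this diff. Judging from the call sites (`PackageRevInfo(str(registry.base), name, top_hash)`), it is a small record of where an in-memory package came from; a rough sketch under that assumption:

```python
from typing import NamedTuple

# Illustrative only: the real definition is not part of this diff.
class PackageRevInfo(NamedTuple):
    registry: str  # str(registry.base), e.g. "s3://my-bucket"
    name: str      # package name, e.g. "user/pkg"
    top_hash: str  # the revision this in-memory package was loaded from
```

With `browse` storing this, a later `push` can treat `_origin.top_hash` as the package's parent hash.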
@@ -1039,8 +1041,7 @@ def _build(self, name, registry, message):
    def _push_manifest(self, name, registry, top_hash):
        manifest = io.BytesIO()
        self._dump(manifest)
        self._timestamp = registry.push_manifest(name, top_hash, manifest.getvalue())
        self._origin = PackageRevInfo(str(registry.base), name, top_hash)
        registry.push_manifest(name, top_hash, manifest.getvalue())

    @ApiTelemetry("package.dump")
    def dump(self, writable_file):
@@ -1290,7 +1291,7 @@ def _get_top_hash_parts(cls, meta, entries):

    @ApiTelemetry("package.push")
    @_fix_docstring(workflow=_WORKFLOW_PARAM_DOCSTRING)
    def push(self, name, registry=None, dest=None, message=None, selector_fn=None, *, workflow=...):
    def push(self, name, registry=None, dest=None, message=None, selector_fn=None, *, workflow=..., force=False):
        """
        Copies objects to path, then creates a new package that points to those objects.
        Copies each object in this package to path according to logical key structure,
@@ -1313,6 +1314,9 @@ def push(self, name, registry=None, dest=None, message=None, selector_fn=None, *
            If `selector_fn('entry_1', pkg["entry_1"]) == True`,
            `new_pkg["entry_1"] = ["s3://bucket/prefix/entry_1.json"]`

        By default, push will not overwrite an existing package if its top hash does not match
        the parent hash of the package being pushed. Use `force=True` to skip the check.

        Args:
            name: name for package in registry
            dest: where to copy the objects in the package
@@ -1326,13 +1330,14 @@ def push(self, name, registry=None, dest=None, message=None, selector_fn=None, *
                are spread over multiple buckets and you add a single local file, you can use selector_fn to
                only push the local file to s3 (instead of pushing all data to the destination bucket).
            %(workflow)s
            force: skip the top hash check and overwrite any existing package

        Returns:
            A new package that points to the copied objects.
        """
        return self._push(name, registry, dest, message, selector_fn, workflow=workflow, print_info=True)
        return self._push(name, registry, dest, message, selector_fn, workflow=workflow, print_info=True, force=force)

    def _push(self, name, registry=None, dest=None, message=None, selector_fn=None, *, workflow, print_info):
    def _push(self, name, registry=None, dest=None, message=None, selector_fn=None, *, workflow, print_info, force):
        if selector_fn is None:
            def selector_fn(*args):
                return True
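For reference, a minimal usage sketch of the new `force` parameter (the package and bucket names here are made up):

```python
import quilt3

pkg = quilt3.Package().set("data.csv", "data.csv")

# Default: raises QuiltException if "user/pkg" already exists at the
# destination under a top hash other than this package's parent hash.
pkg.push("user/pkg", registry="s3://my-bucket")

# force=True skips the top-hash check and overwrites whatever is there.
pkg.push("user/pkg", registry="s3://my-bucket", force=True)
```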
@@ -1391,12 +1396,34 @@ def dest_fn(lk, *args, **kwargs):
        registry = get_package_registry(registry)
        self._validate_with_workflow(registry=registry, workflow=workflow, name=name, message=message)

        def check_latest_hash():
            if force:
                return

            try:
                latest_hash = get_bytes(registry.pointer_latest_pk(name)).decode()
            except botocore.exceptions.ClientError as ex:
                if ex.response['Error']['Code'] == 'NoSuchKey':
                    # Expected
                    return
                raise

            if self._origin is None or latest_hash != self._origin.top_hash:
                raise QuiltException(
                    f"Package with an unexpected hash {latest_hash} already exists at the destination; "
                    "use force=True to overwrite"
                )

Review comment: Might make sense to specify both expected and actual hashes.
Reply: Ok
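A hedged sketch of the reviewer's suggestion, reporting both hashes (not what this commit ships; the `expected` handling is assumed):

```python
expected = self._origin.top_hash if self._origin is not None else None
raise QuiltException(
    f"Package already exists at the destination with top hash {latest_hash}, "
    f"but {expected} was expected; use force=True to overwrite"
)
```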
        # Check the top hash and fail early if it's unexpected.
        check_latest_hash()

        self._fix_sha256()

        pkg = self.__class__()
        pkg._meta = self._meta
        pkg._set_commit_message(message)
        top_hash = self._calculate_top_hash(pkg._meta, self.walk())
        pkg._origin = PackageRevInfo(str(registry.base), name, top_hash)

        # Since all that is modified is physical keys, pkg will have the same top hash
        file_list = []

@@ -1446,6 +1473,9 @@ def physical_key_is_temp_file(pk):
        for lk in temp_file_logical_keys:
            self._set(lk, pkg[lk])

        # Check top hash again just before pushing, to minimize the race condition.
        check_latest_hash()
        pkg._push_manifest(name, registry, top_hash)

        if print_info:
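The two `check_latest_hash()` calls shrink, but cannot close, the window between reading `latest` and writing the manifest, which is presumably why the PR title says "semi-atomic". A hypothetical caller-side retry helper (the helper itself is illustrative; `browse`, `walk`, `set`, and `push` are existing quilt3 APIs):

```python
import quilt3
from quilt3.util import QuiltException

def push_with_rebase(pkg, name, registry, attempts=3):
    """Illustrative only: rebase onto the latest revision and retry on conflict."""
    for _ in range(attempts):
        try:
            return pkg.push(name, registry=registry)
        except QuiltException:
            # Someone pushed first: reapply our entries on top of the new latest.
            latest = quilt3.Package.browse(name, registry=registry)
            for lk, entry in pkg.walk():
                latest.set(lk, entry)
            pkg = latest
    raise RuntimeError("gave up after repeated push conflicts")
```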
Second changed file (test suite):
@@ -8,6 +8,7 @@
from collections import Counter
from contextlib import redirect_stderr
from datetime import datetime
from functools import partial
from io import BytesIO
from pathlib import Path
from unittest import mock
@@ -89,6 +90,17 @@ def setup_s3_stubber_resolve_pointer(self, pkg_registry, pkg_name, *, pointer, t
            }
        )

    def setup_s3_stubber_resolve_pointer_not_found(self, pkg_registry, pkg_name, *, pointer):
        self.s3_stubber.add_client_error(
            method='get_object',
            service_error_code='NoSuchKey',
            http_status_code=404,
            expected_params={
                'Bucket': pkg_registry.root.bucket,
                'Key': pkg_registry.pointer_pk(pkg_name, pointer).path,
            }
        )

    def setup_s3_stubber_delete_pointer(self, pkg_registry, pkg_name, *, pointer):
        self.s3_stubber.add_response(
            method='delete_object',
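The stubbed 404 above exercises the `NoSuchKey` branch of `check_latest_hash`. For context, a self-contained sketch of how botocore's Stubber produces that error (bucket and key are placeholders):

```python
import boto3
import botocore.exceptions
from botocore.stub import Stubber

s3 = boto3.client("s3", region_name="us-east-1")
stubber = Stubber(s3)
stubber.add_client_error(
    method="get_object",
    service_error_code="NoSuchKey",
    http_status_code=404,
    expected_params={"Bucket": "test-bucket", "Key": "some/pointer/key"},
)

with stubber:
    try:
        s3.get_object(Bucket="test-bucket", Key="some/pointer/key")
    except botocore.exceptions.ClientError as ex:
        # The code path check_latest_hash() treats as "package does not exist yet".
        assert ex.response["Error"]["Code"] == "NoSuchKey"
```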
@@ -462,7 +474,7 @@ def add_pkg_file(pkg, lk, filename, data, *, version):
        )
        with patch('time.time', return_value=timestamp1), \
                patch('quilt3.data_transfer.MAX_CONCURRENCY', 1):
            remote_pkg = new_pkg.push(pkg_name, registry)
            remote_pkg = new_pkg.push(pkg_name, registry, force=True)
Review comment: Why do we need …
Reply: Because otherwise, it makes extra S3 calls, so I'd need to mock those. I guess …
Reply: Well, yes, it's simpler, but are we going to use …
        # Modify one file, and check that only that file gets uploaded.
        add_pkg_file(remote_pkg, 'foo2', 'bar3', '!!!', version='v2')
@@ -480,7 +492,7 @@ def add_pkg_file(pkg, lk, filename, data, *, version):
        stderr = io.StringIO()

        with redirect_stderr(stderr), patch('quilt3.packages.DISABLE_TQDM', True):
            remote_pkg.push(pkg_name, registry)
            remote_pkg.push(pkg_name, registry, force=True)
        assert not stderr.getvalue()

    def test_package_deserialize(self):
@@ -717,7 +729,7 @@ def test_set_package_entry_as_object(self):
        # Test that push cleans up the temporary files, if and only if the serialization_location was not set
        with patch('quilt3.Package._push_manifest'), \
                patch('quilt3.packages.copy_file_list', _mock_copy_file_list):
            pkg.push('Quilt/test_pkg_name', 's3://test-bucket', force=True)

        for lk in ["mydataframe1.parquet", "mydataframe2.csv", "mydataframe3.tsv"]:
            file_path = pkg[lk].physical_key.path
@@ -1059,19 +1071,61 @@ def test_push_restrictions(self):

        # disallow pushing not to the top level of a remote S3 registry
        with pytest.raises(QuiltException):
            p.push('Quilt/Test', 's3://test-bucket/foo/bar')
            p.push('Quilt/Test', 's3://test-bucket/foo/bar', force=True)

        # disallow pushing to the local filesystem (use install instead)
        with pytest.raises(QuiltException):
            p.push('Quilt/Test', './')
            p.push('Quilt/Test', './', force=True)

        # disallow pushing the package manifest to remote but package data to local
        with pytest.raises(QuiltException):
            p.push('Quilt/Test', 's3://test-bucket', dest='./')
            p.push('Quilt/Test', 's3://test-bucket', dest='./', force=True)

        # disallow pushing the package manifest to remote but package data to a different remote
        with pytest.raises(QuiltException):
            p.push('Quilt/Test', 's3://test-bucket', dest='s3://other-test-bucket')
            p.push('Quilt/Test', 's3://test-bucket', dest='s3://other-test-bucket', force=True)
    @patch('quilt3.workflows.validate', mock.MagicMock(return_value=None))
    def test_push_conflicts(self):
        registry = 's3://test-bucket'
        pkg_registry = self.S3PackageRegistryDefault(PhysicalKey.from_url(registry))
        pkg_name = 'Quilt/test'

        pkg = Package()

        self.patch_s3_registry('shorten_top_hash', return_value='123456')

        with patch('quilt3.packages.copy_file_list', _mock_copy_file_list), \
                patch('quilt3.Package._push_manifest'):
            # Remote package does not yet exist: push succeeds.

            # push() calls check_latest_hash() twice (once early, once just before
            # writing the manifest), so each successful push needs two stubbed responses.
            for _ in range(2):
                self.setup_s3_stubber_resolve_pointer_not_found(
                    pkg_registry, pkg_name, pointer='latest'
                )

            pkg2 = pkg.push('Quilt/test', 's3://test-bucket')

            # Remote package exists, but has the parent hash: push succeeds.

            pkg2.set('foo', b'123')
            pkg2.build('Quilt/test')

            for _ in range(2):
                self.setup_s3_stubber_resolve_pointer(
                    pkg_registry, pkg_name, pointer='latest', top_hash=pkg.top_hash
                )

            pkg2.push('Quilt/test', 's3://test-bucket')

            # Remote package exists and the hash does not match: push fails.

            # Only one response here: the first check raises before the second lookup.
            self.setup_s3_stubber_resolve_pointer(
                pkg_registry, pkg_name, pointer='latest', top_hash=pkg2.top_hash
            )

            with self.assertRaisesRegex(QuiltException, 'Package with an unexpected hash'):
                pkg2.push('Quilt/test', 's3://test-bucket')
    @patch('quilt3.workflows.validate', return_value=None)
    def test_commit_message_on_push(self, mocked_workflow_validate):

@@ -1083,7 +1137,7 @@ def test_commit_message_on_push(self, mocked_workflow_validate):
        with open(REMOTE_MANIFEST, encoding='utf-8') as fd:
            pkg = Package.load(fd)

        pkg.push('Quilt/test_pkg_name', 's3://test-bucket', message='test_message')
        pkg.push('Quilt/test_pkg_name', 's3://test-bucket', message='test_message', force=True)
        registry = self.S3PackageRegistryDefault(PhysicalKey.from_url('s3://test-bucket'))
        message = 'test_message'
        push_manifest_mock.assert_called_once_with(
@@ -1186,7 +1240,7 @@ def test_manifest(self):
    @patch('quilt3.workflows.validate', mock.MagicMock(return_value='workflow data'))
    def test_manifest_workflow(self):
        self.patch_s3_registry('shorten_top_hash', return_value='7a67ff4')
        for method in (Package.build, Package.push):
        for method in (Package.build, partial(Package.push, force=True)):
            with self.subTest(method=method):
                pkg = Package()
                method(pkg, 'foo/bar', registry='s3://test-bucket')
@@ -1610,7 +1664,7 @@ def test_workflow_validation(self, workflow_validate_mock, copy_file_list_mock):
        pkg_registry = self.S3PackageRegistryDefault(PhysicalKey.from_url('s3://test-bucket'))
        self.patch_s3_registry('shorten_top_hash', return_value='7a67ff4')

        for method in (Package.build, Package.push):
        for method in (Package.build, partial(Package.push, force=True)):
            with self.subTest(method=method):
                with patch('quilt3.Package._push_manifest') as push_manifest_mock:
                    pkg = Package().set('foo', DATA_DIR / 'foo.txt')

@@ -1625,7 +1679,7 @@ def test_workflow_validation(self, workflow_validate_mock, copy_file_list_mock):
                    assert pkg._workflow is mock.sentinel.returned_workflow
                    push_manifest_mock.assert_called_once()
                    workflow_validate_mock.reset_mock()
                    if method is Package.push:
                    if method is not Package.build:
                        copy_file_list_mock.assert_called_once()
                        copy_file_list_mock.reset_mock()
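The `method is Package.push` checks become `method is not Package.build` because `partial(Package.push, force=True)` is a new object, so an identity comparison against `Package.push` would now be False for the push case. A quick illustration of that `functools.partial` behavior:

```python
from functools import partial

def push(pkg, name, *, force=False):
    ...

bound = partial(push, force=True)
assert bound is not push   # partial wraps; it does not return the original
assert bound.func is push  # the original function is still reachable
```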
@@ -1649,7 +1703,7 @@ def test_workflow_validation(self, workflow_validate_mock, copy_file_list_mock):
                    assert pkg._workflow is mock.sentinel.returned_workflow
                    push_manifest_mock.assert_called_once()
                    workflow_validate_mock.reset_mock()
                    if method is Package.push:
                    if method is not Package.build:
                        copy_file_list_mock.assert_called_once()
                        copy_file_list_mock.reset_mock()
@@ -1660,7 +1714,7 @@ def test_push_dest_fn_non_string(self):
            with self.subTest(value=val):
                with pytest.raises(TypeError) as excinfo:
                    pkg.push('foo/bar', registry='s3://test-bucket',
                             dest=(lambda v: lambda *args, **kwargs: v)(val))
                             dest=(lambda v: lambda *args, **kwargs: v)(val), force=True)
                assert 'str is expected' in str(excinfo.value)

    @patch('quilt3.workflows.validate', mock.MagicMock(return_value=None))
@@ -1670,13 +1724,16 @@ def test_push_dest_fn_non_supported_uri(self):
            with self.subTest(value=val):
                with pytest.raises(quilt3.util.URLParseError):
                    pkg.push('foo/bar', registry='s3://test-bucket',
                             dest=(lambda v: lambda *args, **kwargs: v)(val))
                             dest=(lambda v: lambda *args, **kwargs: v)(val), force=True)

    @patch('quilt3.workflows.validate', mock.MagicMock(return_value=None))
    def test_push_dest_fn_s3_uri_with_version_id(self):
        pkg = Package().set('foo', DATA_DIR / 'foo.txt')
        with pytest.raises(ValueError) as excinfo:
            pkg.push('foo/bar', registry='s3://test-bucket', dest=lambda *args, **kwargs: 's3://bucket/ds?versionId=v')
            pkg.push(
                'foo/bar', registry='s3://test-bucket',
                dest=lambda *args, **kwargs: 's3://bucket/ds?versionId=v', force=True
            )
        assert 'URI must not include versionId' in str(excinfo.value)

    @patch('quilt3.workflows.validate', mock.MagicMock(return_value=None))
@@ -1703,7 +1760,7 @@ def test_push_dest_fn(self):
        )
        push_manifest_mock = self.patch_s3_registry('push_manifest')
        self.patch_s3_registry('shorten_top_hash', return_value='7a67ff4')
        pkg.push(pkg_name, registry='s3://test-bucket', dest=dest_fn)
        pkg.push(pkg_name, registry='s3://test-bucket', dest=dest_fn, force=True)

        dest_fn.assert_called_once_with(lk, pkg[lk], mock.sentinel.top_hash)
        push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY)
Review comment: Might make sense to add a specific subclass of QuiltException for that.
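A hedged sketch of the suggested subclass (name, module, and constructor are hypothetical, not part of this PR):

```python
from quilt3.util import QuiltException

class PackageConflictError(QuiltException):
    """Destination's `latest` points at a different revision than this package's parent."""

    def __init__(self, expected_hash, actual_hash):
        self.expected_hash = expected_hash
        self.actual_hash = actual_hash
        super().__init__(
            f"Package already exists with top hash {actual_hash}, "
            f"expected {expected_hash}; use force=True to overwrite"
        )
```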