Implement semi-atomic push #2689

Merged on Feb 22, 2022 (8 commits; changes shown from 6 commits).
46 changes: 38 additions & 8 deletions api/python/quilt3/packages.py
@@ -17,6 +17,7 @@
 from collections import deque
 from multiprocessing import Pool
 
+import botocore.exceptions
 import jsonlines
 from tqdm import tqdm

@@ -375,11 +376,10 @@ def object_pairs_hook(items):
 class Package:
     """ In-memory representation of a package """
 
-    _origin = None
-
     def __init__(self):
         self._children = {}
         self._meta = {'version': 'v0'}
+        self._origin = None
 
     @ApiTelemetry("package.__repr__")
     def __repr__(self, max_lines=20):
@@ -604,7 +604,9 @@ def download_manifest(dst):
             stack.callback(os.unlink, local_pkg_manifest)
             download_manifest(local_pkg_manifest)
 
-            return cls._from_path(local_pkg_manifest)
+            pkg = cls._from_path(local_pkg_manifest)
+            pkg._origin = PackageRevInfo(str(registry.base), name, top_hash)
+            return pkg
 
     @classmethod
     def _from_path(cls, path):
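With this change, `browse` records which revision it loaded, and `push` later compares the registry's current `latest` hash against that origin. A minimal usage sketch of the resulting workflow (bucket and package names are hypothetical):

```python
import quilt3

# Browse a revision; after this PR, pkg._origin remembers
# (registry, name, top_hash) for the revision we started from.
pkg = quilt3.Package.browse('user/pkg', registry='s3://my-bucket')

pkg.set('data.csv', 'data.csv')  # local edit

# Succeeds only while the remote 'latest' still equals our parent hash.
pkg.push('user/pkg', 's3://my-bucket')
```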
@@ -1039,8 +1041,7 @@ def _build(self, name, registry, message):
     def _push_manifest(self, name, registry, top_hash):
         manifest = io.BytesIO()
         self._dump(manifest)
-        self._timestamp = registry.push_manifest(name, top_hash, manifest.getvalue())
-        self._origin = PackageRevInfo(str(registry.base), name, top_hash)
+        registry.push_manifest(name, top_hash, manifest.getvalue())
 
     @ApiTelemetry("package.dump")
     def dump(self, writable_file):
@@ -1290,7 +1291,7 @@ def _get_top_hash_parts(cls, meta, entries):
 
     @ApiTelemetry("package.push")
     @_fix_docstring(workflow=_WORKFLOW_PARAM_DOCSTRING)
-    def push(self, name, registry=None, dest=None, message=None, selector_fn=None, *, workflow=...):
+    def push(self, name, registry=None, dest=None, message=None, selector_fn=None, *, workflow=..., force=False):
         """
         Copies objects to path, then creates a new package that points to those objects.
         Copies each object in this package to path according to logical key structure,
@@ -1313,6 +1314,9 @@ def push(self, name, registry=None, dest=None, message=None, selector_fn=None, *,
         If `selector_fn('entry_1', pkg["entry_1"]) == True`,
         `new_pkg["entry_1"] = ["s3://bucket/prefix/entry_1.json"]`
 
+        By default, push will not overwrite an existing package if its top hash does not match
+        the parent hash of the package being pushed. Use `force=True` to skip the check.
+
         Args:
             name: name for package in registry
             dest: where to copy the objects in the package
@@ -1326,13 +1330,14 @@ def push(self, name, registry=None, dest=None, message=None, selector_fn=None, *,
                 are spread over multiple buckets and you add a single local file, you can use selector_fn to
                 only push the local file to s3 (instead of pushing all data to the destination bucket).
             %(workflow)s
+            force: skip the top hash check and overwrite any existing package
 
         Returns:
             A new package that points to the copied objects.
         """
-        return self._push(name, registry, dest, message, selector_fn, workflow=workflow, print_info=True)
+        return self._push(name, registry, dest, message, selector_fn, workflow=workflow, print_info=True, force=force)
 
-    def _push(self, name, registry=None, dest=None, message=None, selector_fn=None, *, workflow, print_info):
+    def _push(self, name, registry=None, dest=None, message=None, selector_fn=None, *, workflow, print_info, force):
         if selector_fn is None:
             def selector_fn(*args):
                 return True
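A sketch of how a caller might handle the new conflict behavior (names are illustrative; re-applying changes on top of a fresh browse is usually preferable to forcing):

```python
import quilt3
from quilt3.util import QuiltException

pkg = quilt3.Package.browse('user/pkg', registry='s3://my-bucket')
pkg.set('report.txt', 'report.txt')

try:
    pkg.push('user/pkg', 's3://my-bucket')
except QuiltException:
    # 'latest' moved since we browsed. Either rebase onto the new revision...
    fresh = quilt3.Package.browse('user/pkg', registry='s3://my-bucket')
    fresh.set('report.txt', 'report.txt')
    fresh.push('user/pkg', 's3://my-bucket')
    # ...or knowingly clobber the remote revision with force=True.
```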
@@ -1391,12 +1396,34 @@ def dest_fn(lk, *args, **kwargs):
         registry = get_package_registry(registry)
         self._validate_with_workflow(registry=registry, workflow=workflow, name=name, message=message)
 
+        def check_latest_hash():
+            if force:
+                return
+
+            try:
+                latest_hash = get_bytes(registry.pointer_latest_pk(name)).decode()
+            except botocore.exceptions.ClientError as ex:
+                if ex.response['Error']['Code'] == 'NoSuchKey':
+                    # Expected
+                    return
+                raise
+
+            if self._origin is None or latest_hash != self._origin.top_hash:
+                raise QuiltException(

[Member] Might make sense to add a specific subclass of QuiltException for that.

+                    f"Package with an unexpected hash {latest_hash} already exists at the destination; "

[Member] Might make sense to specify both expected and actual hashes.
[Author] Ok

+                    "use force=True to overwrite"
+                )
+
+        # Check the top hash and fail early if it's unexpected.
+        check_latest_hash()
+
         self._fix_sha256()
 
         pkg = self.__class__()
         pkg._meta = self._meta
         pkg._set_commit_message(message)
         top_hash = self._calculate_top_hash(pkg._meta, self.walk())
+        pkg._origin = PackageRevInfo(str(registry.base), name, top_hash)
 
         # Since all that is modified is physical keys, pkg will have the same top hash
         file_list = []
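The two `check_latest_hash()` call sites are what makes the push "semi-atomic": fail fast before any hashing or copying, then re-check just before the manifest is written, so the race window shrinks to the final write (S3 offers no compare-and-swap, so it cannot be closed entirely). The comparison itself reduces to a sketch like this, assuming `get_bytes` comes from `quilt3.data_transfer` as elsewhere in this module:

```python
import botocore.exceptions

from quilt3.data_transfer import get_bytes


def latest_hash_or_none(registry, name):
    """The registry's current 'latest' top hash, or None when the package
    does not exist yet (a NoSuchKey miss is the expected first-push case)."""
    try:
        return get_bytes(registry.pointer_latest_pk(name)).decode()
    except botocore.exceptions.ClientError as ex:
        if ex.response['Error']['Code'] == 'NoSuchKey':
            return None
        raise

# push proceeds when this returns None (no package yet), or when it
# equals self._origin.top_hash (remote hasn't moved since we browsed).
```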
@@ -1446,6 +1473,9 @@ def physical_key_is_temp_file(pk):
         for lk in temp_file_logical_keys:
             self._set(lk, pkg[lk])
 
+        # Check top hash again just before pushing, to minimize the race condition.
+        check_latest_hash()
+
         pkg._push_manifest(name, registry, top_hash)
 
         if print_info:
89 changes: 73 additions & 16 deletions api/python/tests/integration/test_packages.py
@@ -8,6 +8,7 @@
 from collections import Counter
 from contextlib import redirect_stderr
 from datetime import datetime
+from functools import partial
 from io import BytesIO
 from pathlib import Path
 from unittest import mock
@@ -89,6 +90,17 @@ def setup_s3_stubber_resolve_pointer(self, pkg_registry, pkg_name, *, pointer, top_hash):
             }
         )
 
+    def setup_s3_stubber_resolve_pointer_not_found(self, pkg_registry, pkg_name, *, pointer):
+        self.s3_stubber.add_client_error(
+            method='get_object',
+            service_error_code='NoSuchKey',
+            http_status_code=404,
+            expected_params={
+                'Bucket': pkg_registry.root.bucket,
+                'Key': pkg_registry.pointer_pk(pkg_name, pointer).path,
+            }
+        )
+
     def setup_s3_stubber_delete_pointer(self, pkg_registry, pkg_name, *, pointer):
         self.s3_stubber.add_response(
             method='delete_object',
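For reference, botocore's `Stubber` queues canned responses or errors and asserts the expected request parameters. The helper above is roughly equivalent to this standalone sketch (bucket, key path, and region are hypothetical):

```python
import boto3
import botocore.exceptions
from botocore.stub import Stubber

s3 = boto3.client('s3', region_name='us-east-1')
stubber = Stubber(s3)
stubber.add_client_error(
    'get_object',
    service_error_code='NoSuchKey',
    http_status_code=404,
    expected_params={'Bucket': 'test-bucket',
                     'Key': '.quilt/named_packages/Quilt/test/latest'},
)

with stubber:
    try:
        s3.get_object(Bucket='test-bucket',
                      Key='.quilt/named_packages/Quilt/test/latest')
    except botocore.exceptions.ClientError as ex:
        # The error code push treats as "no package exists yet".
        assert ex.response['Error']['Code'] == 'NoSuchKey'
```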
@@ -462,7 +474,7 @@ def add_pkg_file(pkg, lk, filename, data, *, version):
         )
         with patch('time.time', return_value=timestamp1), \
              patch('quilt3.data_transfer.MAX_CONCURRENCY', 1):
-            remote_pkg = new_pkg.push(pkg_name, registry)
+            remote_pkg = new_pkg.push(pkg_name, registry, force=True)

[Member] Why do we need force=True here?
[Author] Because otherwise, it makes extra S3 calls, so I'd need to mock those. I guess setup_s3_stubber_push_manifest could be updated to mock them by default, but force=True plus a new unit test for push conflicts seemed simpler.
[Member] > but force=True plus a new unit test for push conflicts seemed simpler
Well, yes, it's simpler, but are we going to use force=True in every new test?

 
         # Modify one file, and check that only that file gets uploaded.
         add_pkg_file(remote_pkg, 'foo2', 'bar3', '!!!', version='v2')
@@ -480,7 +492,7 @@ def add_pkg_file(pkg, lk, filename, data, *, version):
         stderr = io.StringIO()
 
         with redirect_stderr(stderr), patch('quilt3.packages.DISABLE_TQDM', True):
-            remote_pkg.push(pkg_name, registry)
+            remote_pkg.push(pkg_name, registry, force=True)
         assert not stderr.getvalue()
 
     def test_package_deserialize(self):
@@ -717,7 +729,7 @@ def test_set_package_entry_as_object(self):
         # Test that push cleans up the temporary files, if and only if the serialization_location was not set
         with patch('quilt3.Package._push_manifest'), \
              patch('quilt3.packages.copy_file_list', _mock_copy_file_list):
-            pkg.push('Quilt/test_pkg_name', 's3://test-bucket')
+            pkg.push('Quilt/test_pkg_name', 's3://test-bucket', force=True)
 
         for lk in ["mydataframe1.parquet", "mydataframe2.csv", "mydataframe3.tsv"]:
             file_path = pkg[lk].physical_key.path
@@ -1059,19 +1071,61 @@ def test_push_restrictions(self):
 
         # disallow pushing not to the top level of a remote S3 registry
         with pytest.raises(QuiltException):
-            p.push('Quilt/Test', 's3://test-bucket/foo/bar')
+            p.push('Quilt/Test', 's3://test-bucket/foo/bar', force=True)
 
         # disallow pushing to the local filesystem (use install instead)
         with pytest.raises(QuiltException):
-            p.push('Quilt/Test', './')
+            p.push('Quilt/Test', './', force=True)
 
         # disallow pushing the package manifest to remote but package data to local
         with pytest.raises(QuiltException):
-            p.push('Quilt/Test', 's3://test-bucket', dest='./')
+            p.push('Quilt/Test', 's3://test-bucket', dest='./', force=True)
 
         # disallow pushing the package manifest to remote but package data to a different remote
         with pytest.raises(QuiltException):
-            p.push('Quilt/Test', 's3://test-bucket', dest='s3://other-test-bucket')
+            p.push('Quilt/Test', 's3://test-bucket', dest='s3://other-test-bucket', force=True)
 
+    @patch('quilt3.workflows.validate', mock.MagicMock(return_value=None))
+    def test_push_conflicts(self):
+        registry = 's3://test-bucket'
+        pkg_registry = self.S3PackageRegistryDefault(PhysicalKey.from_url(registry))
+        pkg_name = 'Quilt/test'
+
+        pkg = Package()
+
+        self.patch_s3_registry('shorten_top_hash', return_value='123456')
+
+        with patch('quilt3.packages.copy_file_list', _mock_copy_file_list), \
+             patch('quilt3.Package._push_manifest'):
+            # Remote package does not yet exist: push succeeds.
+
+            for _ in range(2):
+                self.setup_s3_stubber_resolve_pointer_not_found(
+                    pkg_registry, pkg_name, pointer='latest'
+                )
+
+            pkg2 = pkg.push('Quilt/test', 's3://test-bucket')
+
+            # Remote package exists, but has the parent hash: push succeeds.
+
+            pkg2.set('foo', b'123')
+            pkg2.build('Quilt/test')
+
+            for _ in range(2):
+                self.setup_s3_stubber_resolve_pointer(
+                    pkg_registry, pkg_name, pointer='latest', top_hash=pkg.top_hash
+                )
+
+            pkg2.push('Quilt/test', 's3://test-bucket')
+
+            # Remote package exists and the hash does not match: push fails.
+
+            self.setup_s3_stubber_resolve_pointer(
+                pkg_registry, pkg_name, pointer='latest', top_hash=pkg2.top_hash
+            )
+
+            with self.assertRaisesRegex(QuiltException, 'Package with an unexpected hash'):
+                pkg2.push('Quilt/test', 's3://test-bucket')
+
     @patch('quilt3.workflows.validate', return_value=None)
     def test_commit_message_on_push(self, mocked_workflow_validate):
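Note the `for _ in range(2)` loops in test_push_conflicts: each successful push resolves the `latest` pointer twice, once for the early check and once for the re-check right before the manifest is written, so two stubbed responses must be queued per call. The failing case queues only one, because the first check already raises.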
@@ -1083,7 +1137,7 @@ def test_commit_message_on_push(self, mocked_workflow_validate):
         with open(REMOTE_MANIFEST, encoding='utf-8') as fd:
             pkg = Package.load(fd)
 
-        pkg.push('Quilt/test_pkg_name', 's3://test-bucket', message='test_message')
+        pkg.push('Quilt/test_pkg_name', 's3://test-bucket', message='test_message', force=True)
         registry = self.S3PackageRegistryDefault(PhysicalKey.from_url('s3://test-bucket'))
         message = 'test_message'
         push_manifest_mock.assert_called_once_with(
@@ -1186,7 +1240,7 @@ def test_manifest(self):
     @patch('quilt3.workflows.validate', mock.MagicMock(return_value='workflow data'))
     def test_manifest_workflow(self):
         self.patch_s3_registry('shorten_top_hash', return_value='7a67ff4')
-        for method in (Package.build, Package.push):
+        for method in (Package.build, partial(Package.push, force=True)):
             with self.subTest(method=method):
                 pkg = Package()
                 method(pkg, 'foo/bar', registry='s3://test-bucket')
@@ -1610,7 +1664,7 @@ def test_workflow_validation(self, workflow_validate_mock, copy_file_list_mock):
         pkg_registry = self.S3PackageRegistryDefault(PhysicalKey.from_url('s3://test-bucket'))
         self.patch_s3_registry('shorten_top_hash', return_value='7a67ff4')
 
-        for method in (Package.build, Package.push):
+        for method in (Package.build, partial(Package.push, force=True)):
             with self.subTest(method=method):
                 with patch('quilt3.Package._push_manifest') as push_manifest_mock:
                     pkg = Package().set('foo', DATA_DIR / 'foo.txt')
@@ -1625,7 +1679,7 @@ def test_workflow_validation(self, workflow_validate_mock, copy_file_list_mock):
                     assert pkg._workflow is mock.sentinel.returned_workflow
                     push_manifest_mock.assert_called_once()
                     workflow_validate_mock.reset_mock()
-                    if method is Package.push:
+                    if method is not Package.build:
                         copy_file_list_mock.assert_called_once()
                         copy_file_list_mock.reset_mock()

@@ -1649,7 +1703,7 @@ def test_workflow_validation(self, workflow_validate_mock, copy_file_list_mock):
                 assert pkg._workflow is mock.sentinel.returned_workflow
                 push_manifest_mock.assert_called_once()
                 workflow_validate_mock.reset_mock()
-                if method is Package.push:
+                if method is not Package.build:
                     copy_file_list_mock.assert_called_once()
                     copy_file_list_mock.reset_mock()

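The identity checks above also had to change from `method is Package.push` to `method is not Package.build`, because `functools.partial(Package.push, force=True)` is a distinct object that no longer compares identical to `Package.push`:

```python
from functools import partial

def f(x, *, flag=False):
    return x

# A partial wraps the function in a new object, so identity with the
# original function is lost.
assert partial(f, flag=True) is not f
```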
@@ -1660,7 +1714,7 @@ def test_push_dest_fn_non_string(self):
             with self.subTest(value=val):
                 with pytest.raises(TypeError) as excinfo:
                     pkg.push('foo/bar', registry='s3://test-bucket',
-                             dest=(lambda v: lambda *args, **kwargs: v)(val))
+                             dest=(lambda v: lambda *args, **kwargs: v)(val), force=True)
                 assert 'str is expected' in str(excinfo.value)
 
     @patch('quilt3.workflows.validate', mock.MagicMock(return_value=None))
@@ -1670,13 +1724,16 @@ def test_push_dest_fn_non_supported_uri(self):
             with self.subTest(value=val):
                 with pytest.raises(quilt3.util.URLParseError):
                     pkg.push('foo/bar', registry='s3://test-bucket',
-                             dest=(lambda v: lambda *args, **kwargs: v)(val))
+                             dest=(lambda v: lambda *args, **kwargs: v)(val), force=True)
 
     @patch('quilt3.workflows.validate', mock.MagicMock(return_value=None))
     def test_push_dest_fn_s3_uri_with_version_id(self):
         pkg = Package().set('foo', DATA_DIR / 'foo.txt')
         with pytest.raises(ValueError) as excinfo:
-            pkg.push('foo/bar', registry='s3://test-bucket', dest=lambda *args, **kwargs: 's3://bucket/ds?versionId=v')
+            pkg.push(
+                'foo/bar', registry='s3://test-bucket',
+                dest=lambda *args, **kwargs: 's3://bucket/ds?versionId=v', force=True
+            )
         assert 'URI must not include versionId' in str(excinfo.value)
 
     @patch('quilt3.workflows.validate', mock.MagicMock(return_value=None))
@@ -1703,7 +1760,7 @@ def test_push_dest_fn(self):
         )
         push_manifest_mock = self.patch_s3_registry('push_manifest')
         self.patch_s3_registry('shorten_top_hash', return_value='7a67ff4')
-        pkg.push(pkg_name, registry='s3://test-bucket', dest=dest_fn)
+        pkg.push(pkg_name, registry='s3://test-bucket', dest=dest_fn, force=True)
 
         dest_fn.assert_called_once_with(lk, pkg[lk], mock.sentinel.top_hash)
         push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY)
6 changes: 5 additions & 1 deletion docs/api-reference/Package.md
@@ -270,7 +270,7 @@ __Raises__
 * `KeyError`: when logical_key is not present to be deleted
 
 
-## Package.push(self, name, registry=None, dest=None, message=None, selector\_fn=None, \*, workflow=Ellipsis) {#Package.push}
+## Package.push(self, name, registry=None, dest=None, message=None, selector\_fn=None, \*, workflow=Ellipsis, force=False) {#Package.push}
 
 Copies objects to path, then creates a new package that points to those objects.
 Copies each object in this package to path according to logical key structure,
@@ -293,6 +293,9 @@ If `selector_fn('entry_1', pkg["entry_1"]) == False`,
 If `selector_fn('entry_1', pkg["entry_1"]) == True`,
 `new_pkg["entry_1"] = ["s3://bucket/prefix/entry_1.json"]`
 
+By default, push will not overwrite an existing package if its top hash does not match
+the parent hash of the package being pushed. Use `force=True` to skip the check.
+
 __Arguments__
 
 * __name__: name for package in registry
@@ -310,6 +313,7 @@ __Arguments__
     If not specified, the default workflow will be used.
     * __For details see__: https://docs.quiltdata.com/advanced-usage/workflows
 
+* __force__: skip the top hash check and overwrite any existing package
 
 __Returns__
 