Skip to content

Commit

Permalink
feat: add timeout parameter to Batch interface to match google-cloud-…
Browse files Browse the repository at this point in the history
…core (#10010)

* fix: increase minimum version of google-cloud-core after a required field is introduced
* feat: extend batch interface to have timeout to match new google-cloud-core interface
  • Loading branch information
crwilcox authored Dec 26, 2019
1 parent 4f7d4b1 commit ae22885
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 38 deletions.
25 changes: 17 additions & 8 deletions storage/google/cloud/storage/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __init__(self, client):
self._requests = []
self._target_objects = []

def _do_request(self, method, url, headers, data, target_object):
def _do_request(self, method, url, headers, data, target_object, timeout=None):
"""Override Connection: defer actual HTTP request.
Only allow up to ``_MAX_BATCH_SIZE`` requests to be deferred.
Expand All @@ -173,6 +173,12 @@ def _do_request(self, method, url, headers, data, target_object):
connection. Here we defer an HTTP request and complete
initialization of the object at a later time.
:type timeout: float or tuple
:param timeout: (optional) The amount of time, in seconds, to wait
for the server response. By default, the method waits indefinitely.
Can also be passed as a tuple (connect_timeout, read_timeout).
See :meth:`requests.Session.request` documentation for details.
:rtype: tuple of ``response`` (a dictionary of sorts)
and ``content`` (a string).
:returns: The HTTP response object and the content of the response.
Expand All @@ -181,7 +187,7 @@ def _do_request(self, method, url, headers, data, target_object):
raise ValueError(
"Too many deferred requests (max %d)" % self._MAX_BATCH_SIZE
)
self._requests.append((method, url, headers, data))
self._requests.append((method, url, headers, data, timeout))
result = _FutureDict()
self._target_objects.append(target_object)
if target_object is not None:
Expand All @@ -200,9 +206,12 @@ def _prepare_batch_request(self):

multi = MIMEMultipart()

for method, uri, headers, body in self._requests:
# Use timeout of last request, default to None (indefinite)
timeout = None
for method, uri, headers, body, _timeout in self._requests:
subrequest = MIMEApplicationHTTP(method, uri, headers, body)
multi.attach(subrequest)
timeout = _timeout

# The `email` package expects to deal with "native" strings
if six.PY3: # pragma: NO COVER Python3
Expand All @@ -215,7 +224,7 @@ def _prepare_batch_request(self):

# Strip off redundant header text
_, body = payload.split("\n\n", 1)
return dict(multi._headers), body
return dict(multi._headers), body, timeout

def _finish_futures(self, responses):
"""Apply all the batch responses to the futures created.
Expand All @@ -230,7 +239,7 @@ def _finish_futures(self, responses):
# until all futures have been populated.
exception_args = None

if len(self._target_objects) != len(responses):
if len(self._target_objects) != len(responses): # pragma: NO COVER
raise ValueError("Expected a response for every request.")

for target_object, subresponse in zip(self._target_objects, responses):
Expand All @@ -251,15 +260,15 @@ def finish(self):
:rtype: list of tuples
:returns: one ``(headers, payload)`` tuple per deferred request.
"""
headers, body = self._prepare_batch_request()
headers, body, timeout = self._prepare_batch_request()

url = "%s/batch/storage/v1" % self.API_BASE_URL

# Use the private ``_base_connection`` rather than the property
# ``_connection``, since the property may be this
# current batch.
response = self._client._base_connection._make_request(
"POST", url, data=body, headers=headers
"POST", url, data=body, headers=headers, timeout=timeout
)
responses = list(_unpack_batch_response(response))
self._finish_futures(responses)
Expand Down Expand Up @@ -313,7 +322,7 @@ def _unpack_batch_response(response):
parser = Parser()
message = _generate_faux_mime_message(parser, response)

if not isinstance(message._payload, list):
if not isinstance(message._payload, list): # pragma: NO COVER
raise ValueError("Bad response: not multi-part")

for subrequest in message._payload:
Expand Down
4 changes: 2 additions & 2 deletions storage/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
# 'Development Status :: 5 - Production/Stable'
release_status = "Development Status :: 5 - Production/Stable"
dependencies = [
"google-auth >= 1.2.0",
"google-cloud-core >= 1.0.3, < 2.0dev",
"google-auth >= 1.9.0, < 2.0dev",
"google-cloud-core >= 1.1.0, < 2.0dev",
"google-resumable-media >= 0.5.0, < 0.6dev",
]
extras = {}
Expand Down
6 changes: 5 additions & 1 deletion storage/tests/unit/test__http.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ def test_extra_headers(self):
}
expected_uri = conn.build_api_url("/rainbow")
http.request.assert_called_once_with(
data=req_data, headers=expected_headers, method="GET", url=expected_uri
data=req_data,
headers=expected_headers,
method="GET",
url=expected_uri,
timeout=None,
)

def test_build_api_url_no_extra_query_params(self):
Expand Down
26 changes: 18 additions & 8 deletions storage/tests/unit/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test__make_request_GET_normal(self):
# Check the queued request
self.assertEqual(len(batch._requests), 1)
request = batch._requests[0]
request_method, request_url, _, request_data = request
request_method, request_url, _, request_data, _ = request
self.assertEqual(request_method, "GET")
self.assertEqual(request_url, url)
self.assertIsNone(request_data)
Expand All @@ -174,7 +174,7 @@ def test__make_request_POST_normal(self):
http.request.assert_not_called()

request = batch._requests[0]
request_method, request_url, _, request_data = request
request_method, request_url, _, request_data, _ = request
self.assertEqual(request_method, "POST")
self.assertEqual(request_url, url)
self.assertEqual(request_data, data)
Expand All @@ -201,7 +201,7 @@ def test__make_request_PATCH_normal(self):
http.request.assert_not_called()

request = batch._requests[0]
request_method, request_url, _, request_data = request
request_method, request_url, _, request_data, _ = request
self.assertEqual(request_method, "PATCH")
self.assertEqual(request_url, url)
self.assertEqual(request_data, data)
Expand All @@ -228,7 +228,7 @@ def test__make_request_DELETE_normal(self):
# Check the queued request
self.assertEqual(len(batch._requests), 1)
request = batch._requests[0]
request_method, request_url, _, request_data = request
request_method, request_url, _, request_data, _ = request
self.assertEqual(request_method, "DELETE")
self.assertEqual(request_url, url)
self.assertIsNone(request_data)
Expand Down Expand Up @@ -340,7 +340,11 @@ def test_finish_nonempty(self):

expected_url = "{}/batch/storage/v1".format(batch.API_BASE_URL)
http.request.assert_called_once_with(
method="POST", url=expected_url, headers=mock.ANY, data=mock.ANY
method="POST",
url=expected_url,
headers=mock.ANY,
data=mock.ANY,
timeout=mock.ANY,
)

request_info = self._get_mutlipart_request(http)
Expand Down Expand Up @@ -406,7 +410,11 @@ def test_finish_nonempty_with_status_failure(self):

expected_url = "{}/batch/storage/v1".format(batch.API_BASE_URL)
http.request.assert_called_once_with(
method="POST", url=expected_url, headers=mock.ANY, data=mock.ANY
method="POST",
url=expected_url,
headers=mock.ANY,
data=mock.ANY,
timeout=mock.ANY,
)

_, request_body, _, boundary = self._get_mutlipart_request(http)
Expand Down Expand Up @@ -620,8 +628,10 @@ class _Connection(object):
def __init__(self, **kw):
self.__dict__.update(kw)

def _make_request(self, method, url, data=None, headers=None):
return self.http.request(url=url, method=method, headers=headers, data=data)
def _make_request(self, method, url, data=None, headers=None, timeout=None):
return self.http.request(
url=url, method=method, headers=headers, data=data, timeout=timeout
)


class _MockObject(object):
Expand Down
54 changes: 35 additions & 19 deletions storage/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def test_get_service_account_email_wo_project(self):
]
)
http.request.assert_called_once_with(
method="GET", url=URI, data=None, headers=mock.ANY
method="GET", url=URI, data=None, headers=mock.ANY, timeout=mock.ANY
)

def test_get_service_account_email_w_project(self):
Expand All @@ -297,7 +297,7 @@ def test_get_service_account_email_w_project(self):
]
)
http.request.assert_called_once_with(
method="GET", url=URI, data=None, headers=mock.ANY
method="GET", url=URI, data=None, headers=mock.ANY, timeout=mock.ANY
)

def test_bucket(self):
Expand Down Expand Up @@ -366,7 +366,7 @@ def test_get_bucket_with_string_miss(self):
client.get_bucket(NONESUCH)

http.request.assert_called_once_with(
method="GET", url=URI, data=mock.ANY, headers=mock.ANY
method="GET", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)

def test_get_bucket_with_string_hit(self):
Expand Down Expand Up @@ -396,7 +396,7 @@ def test_get_bucket_with_string_hit(self):
self.assertIsInstance(bucket, Bucket)
self.assertEqual(bucket.name, BUCKET_NAME)
http.request.assert_called_once_with(
method="GET", url=URI, data=mock.ANY, headers=mock.ANY
method="GET", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)

def test_get_bucket_with_object_miss(self):
Expand Down Expand Up @@ -427,7 +427,7 @@ def test_get_bucket_with_object_miss(self):
client.get_bucket(bucket_obj)

http.request.assert_called_once_with(
method="GET", url=URI, data=mock.ANY, headers=mock.ANY
method="GET", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)

def test_get_bucket_with_object_hit(self):
Expand Down Expand Up @@ -458,7 +458,7 @@ def test_get_bucket_with_object_hit(self):
self.assertIsInstance(bucket, Bucket)
self.assertEqual(bucket.name, bucket_name)
http.request.assert_called_once_with(
method="GET", url=URI, data=mock.ANY, headers=mock.ANY
method="GET", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)

def test_lookup_bucket_miss(self):
Expand All @@ -485,7 +485,7 @@ def test_lookup_bucket_miss(self):

self.assertIsNone(bucket)
http.request.assert_called_once_with(
method="GET", url=URI, data=mock.ANY, headers=mock.ANY
method="GET", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)

def test_lookup_bucket_hit(self):
Expand Down Expand Up @@ -514,7 +514,7 @@ def test_lookup_bucket_hit(self):
self.assertIsInstance(bucket, Bucket)
self.assertEqual(bucket.name, BUCKET_NAME)
http.request.assert_called_once_with(
method="GET", url=URI, data=mock.ANY, headers=mock.ANY
method="GET", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)

def test_create_bucket_w_missing_client_project(self):
Expand Down Expand Up @@ -666,7 +666,7 @@ def test_create_bucket_w_string_success(self):
self.assertEqual(bucket.name, bucket_name)
self.assertTrue(bucket.requester_pays)
http.request.assert_called_once_with(
method="POST", url=URI, data=mock.ANY, headers=mock.ANY
method="POST", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)
json_sent = http.request.call_args_list[0][1]["data"]
self.assertEqual(json_expected, json.loads(json_sent))
Expand Down Expand Up @@ -706,7 +706,7 @@ def test_create_bucket_w_object_success(self):
self.assertEqual(bucket.name, bucket_name)
self.assertTrue(bucket.requester_pays)
http.request.assert_called_once_with(
method="POST", url=URI, data=mock.ANY, headers=mock.ANY
method="POST", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)
json_sent = http.request.call_args_list[0][1]["data"]
self.assertEqual(json_expected, json.loads(json_sent))
Expand Down Expand Up @@ -848,7 +848,11 @@ def test_list_buckets_empty(self):
self.assertEqual(len(buckets), 0)

http.request.assert_called_once_with(
method="GET", url=mock.ANY, data=mock.ANY, headers=mock.ANY
method="GET",
url=mock.ANY,
data=mock.ANY,
headers=mock.ANY,
timeout=mock.ANY,
)

requested_url = http.request.mock_calls[0][2]["url"]
Expand Down Expand Up @@ -883,7 +887,11 @@ def test_list_buckets_explicit_project(self):
self.assertEqual(len(buckets), 0)

http.request.assert_called_once_with(
method="GET", url=mock.ANY, data=mock.ANY, headers=mock.ANY
method="GET",
url=mock.ANY,
data=mock.ANY,
headers=mock.ANY,
timeout=mock.ANY,
)

requested_url = http.request.mock_calls[0][2]["url"]
Expand Down Expand Up @@ -918,7 +926,11 @@ def test_list_buckets_non_empty(self):
self.assertEqual(buckets[0].name, BUCKET_NAME)

http.request.assert_called_once_with(
method="GET", url=mock.ANY, data=mock.ANY, headers=mock.ANY
method="GET",
url=mock.ANY,
data=mock.ANY,
headers=mock.ANY,
timeout=mock.ANY,
)

def test_list_buckets_all_arguments(self):
Expand Down Expand Up @@ -948,7 +960,11 @@ def test_list_buckets_all_arguments(self):
buckets = list(iterator)
self.assertEqual(buckets, [])
http.request.assert_called_once_with(
method="GET", url=mock.ANY, data=mock.ANY, headers=mock.ANY
method="GET",
url=mock.ANY,
data=mock.ANY,
headers=mock.ANY,
timeout=mock.ANY,
)

requested_url = http.request.mock_calls[0][2]["url"]
Expand Down Expand Up @@ -1077,7 +1093,7 @@ def _create_hmac_key_helper(self, explicit_project=None, user_project=None):

FULL_URI = "{}?{}".format(URI, urlencode(qs_params))
http.request.assert_called_once_with(
method="POST", url=FULL_URI, data=None, headers=mock.ANY
method="POST", url=FULL_URI, data=None, headers=mock.ANY, timeout=mock.ANY
)

def test_create_hmac_key_defaults(self):
Expand Down Expand Up @@ -1112,7 +1128,7 @@ def test_list_hmac_keys_defaults_empty(self):
]
)
http.request.assert_called_once_with(
method="GET", url=URI, data=None, headers=mock.ANY
method="GET", url=URI, data=None, headers=mock.ANY, timeout=mock.ANY
)

def test_list_hmac_keys_explicit_non_empty(self):
Expand Down Expand Up @@ -1176,7 +1192,7 @@ def test_list_hmac_keys_explicit_non_empty(self):
"userProject": USER_PROJECT,
}
http.request.assert_called_once_with(
method="GET", url=mock.ANY, data=None, headers=mock.ANY
method="GET", url=mock.ANY, data=None, headers=mock.ANY, timeout=mock.ANY
)
kwargs = http.request.mock_calls[0].kwargs
uri = kwargs["url"]
Expand Down Expand Up @@ -1223,7 +1239,7 @@ def test_get_hmac_key_metadata_wo_project(self):
]
)
http.request.assert_called_once_with(
method="GET", url=URI, data=None, headers=mock.ANY
method="GET", url=URI, data=None, headers=mock.ANY, timeout=mock.ANY
)

def test_get_hmac_key_metadata_w_project(self):
Expand Down Expand Up @@ -1273,5 +1289,5 @@ def test_get_hmac_key_metadata_w_project(self):
FULL_URI = "{}?{}".format(URI, urlencode(qs_params))

http.request.assert_called_once_with(
method="GET", url=FULL_URI, data=None, headers=mock.ANY
method="GET", url=FULL_URI, data=None, headers=mock.ANY, timeout=mock.ANY
)

0 comments on commit ae22885

Please sign in to comment.