Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Always use raw response data. (#87)" #103

Merged
merged 4 commits into from
Sep 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/latest/.buildinfo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: 68666af324b279e5e7d7ada9d27ddd71
config: b911398855f7668d2aa125e82024376d
tags: 645f666f9bcd5a90fca523b33c5a78b7
40 changes: 14 additions & 26 deletions google/resumable_media/_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,36 +354,24 @@ def _process_response(self, response):
self._get_status_code,
callback=self._make_invalid,
)
headers = self._get_headers(response)
response_body = self._get_body(response)

start_byte, end_byte, total_bytes = get_range_info(
content_length = _helpers.header_required(
response, u"content-length", self._get_headers, callback=self._make_invalid
)
num_bytes = int(content_length)
_, end_byte, total_bytes = get_range_info(
response, self._get_headers, callback=self._make_invalid
)

transfer_encoding = headers.get(u"transfer-encoding")

if transfer_encoding is None:
content_length = _helpers.header_required(
response_body = self._get_body(response)
if len(response_body) != num_bytes:
self._make_invalid()
raise common.InvalidResponse(
response,
u"content-length",
self._get_headers,
callback=self._make_invalid,
u"Response is different size than content-length",
u"Expected",
num_bytes,
u"Received",
len(response_body),
)
num_bytes = int(content_length)
if len(response_body) != num_bytes:
self._make_invalid()
raise common.InvalidResponse(
response,
u"Response is different size than content-length",
u"Expected",
num_bytes,
u"Received",
len(response_body),
)
else:
# 'content-length' header not allowed with chunked encoding.
num_bytes = end_byte - start_byte + 1

# First update ``bytes_downloaded``.
self._bytes_downloaded += num_bytes
Expand Down
8 changes: 1 addition & 7 deletions google/resumable_media/requests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


_DEFAULT_RETRY_STRATEGY = common.RetryStrategy()
_SINGLE_GET_CHUNK_SIZE = 8192
# The number of seconds to wait to establish a connection
# (connect() call on socket). Avoid setting this to a multiple of 3 to not
# Align with TCP Retransmission timing. (typically 2.5-3s)
Expand Down Expand Up @@ -76,12 +75,7 @@ def _get_body(response):
Returns:
bytes: The body of the ``response``.
"""
if response._content is False:
response._content = b"".join(
response.raw.stream(_SINGLE_GET_CHUNK_SIZE, decode_content=False)
)
response._content_consumed = True
return response._content
return response.content


def http_request(
Expand Down
97 changes: 77 additions & 20 deletions google/resumable_media/requests/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
import hashlib
import logging

import urllib3.response

from google.resumable_media import _download
from google.resumable_media import common
from google.resumable_media.requests import _helpers


_LOGGER = logging.getLogger(__name__)
_SINGLE_GET_CHUNK_SIZE = 8192
_HASH_HEADER = u"x-goog-hash"
_MISSING_MD5 = u"""\
No MD5 checksum was returned from the service while downloading {}
Expand Down Expand Up @@ -113,13 +116,13 @@ def _write_to_stream(self, response):
with response:
# NOTE: This might "donate" ``md5_hash`` to the decoder and replace
# it with a ``_DoNothingHash``.
body_iter = response.raw.stream(
_helpers._SINGLE_GET_CHUNK_SIZE, decode_content=False
local_hash = _add_decoder(response.raw, md5_hash)
body_iter = response.iter_content(
chunk_size=_SINGLE_GET_CHUNK_SIZE, decode_unicode=False
)
for chunk in body_iter:
self._stream.write(chunk)
md5_hash.update(chunk)
response._content_consumed = True
local_hash.update(chunk)

if expected_md5_hash is None:
return
Expand Down Expand Up @@ -155,22 +158,22 @@ def consume(self, transport):
"""
method, url, payload, headers = self._prepare_request()
# NOTE: We assume "payload is None" but pass it along anyway.
response = _helpers.http_request(
transport,
method,
url,
data=payload,
headers=headers,
retry_strategy=self._retry_strategy,
stream=True,
)
request_kwargs = {
u"data": payload,
u"headers": headers,
u"retry_strategy": self._retry_strategy,
}
if self._stream is not None:
request_kwargs[u"stream"] = True

self._process_response(response)
result = _helpers.http_request(transport, method, url, **request_kwargs)

self._process_response(result)

if self._stream is not None:
self._write_to_stream(response)
self._write_to_stream(result)

return response
return result


class ChunkedDownload(_helpers.RequestsMixin, _download.ChunkedDownload):
Expand Down Expand Up @@ -216,17 +219,16 @@ def consume_next_chunk(self, transport):
"""
method, url, payload, headers = self._prepare_request()
# NOTE: We assume "payload is None" but pass it along anyway.
response = _helpers.http_request(
result = _helpers.http_request(
transport,
method,
url,
data=payload,
headers=headers,
retry_strategy=self._retry_strategy,
stream=True,
)
self._process_response(response)
return response
self._process_response(result)
return result


def _parse_md5_header(header_value, response):
Expand Down Expand Up @@ -294,3 +296,58 @@ def update(self, unused_chunk):
Args:
unused_chunk (bytes): A chunk of data.
"""


def _add_decoder(response_raw, md5_hash):
"""Patch the ``_decoder`` on a ``urllib3`` response.

This is so that we can intercept the compressed bytes before they are
decoded.

Only patches if the content encoding is ``gzip``.

Args:
response_raw (urllib3.response.HTTPResponse): The raw response for
an HTTP request.
md5_hash (Union[_DoNothingHash, hashlib.md5]): A hash function which
will get updated when it encounters compressed bytes.

Returns:
Union[_DoNothingHash, hashlib.md5]: Either the original ``md5_hash``
if ``_decoder`` is not patched. Otherwise, returns a ``_DoNothingHash``
since the caller will no longer need to hash to decoded bytes.
"""
encoding = response_raw.headers.get(u"content-encoding", u"").lower()
if encoding != u"gzip":
return md5_hash

response_raw._decoder = _GzipDecoder(md5_hash)
return _DoNothingHash()


class _GzipDecoder(urllib3.response.GzipDecoder):
"""Custom subclass of ``urllib3`` decoder for ``gzip``-ed bytes.

Allows an MD5 hash function to see the compressed bytes before they are
decoded. This way the hash of the compressed value can be computed.

Args:
md5_hash (Union[_DoNothingHash, hashlib.md5]): A hash function which
will get updated when it encounters compressed bytes.
"""

def __init__(self, md5_hash):
super(_GzipDecoder, self).__init__()
self._md5_hash = md5_hash

def decompress(self, data):
"""Decompress the bytes.

Args:
data (bytes): The compressed bytes to be decompressed.

Returns:
bytes: The decompressed bytes from ``data``.
"""
self._md5_hash.update(data)
return super(_GzipDecoder, self).decompress(data)
24 changes: 10 additions & 14 deletions tests/system/requests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@
from six.moves import http_client

from google import resumable_media
from google.resumable_media import requests as resumable_requests
from google.resumable_media.requests import download as download_mod
from google.resumable_media.requests import _helpers
import google.resumable_media.requests as resumable_requests
import google.resumable_media.requests.download as download_mod
from tests.system import utils


Expand Down Expand Up @@ -61,11 +60,12 @@
{
u"path": os.path.realpath(os.path.join(DATA_DIR, u"file.txt")),
u"content_type": PLAIN_TEXT,
u"checksum": u"XHSHAr/SpIeZtZbjgQ4nGw==",
u"checksum": u"KHRs/+ZSrc/FuuR4qz/PZQ==",
u"slices": (),
},
{
u"path": os.path.realpath(os.path.join(DATA_DIR, u"gzipped.txt.gz")),
u"uncompressed": os.path.realpath(os.path.join(DATA_DIR, u"gzipped.txt")),
u"content_type": PLAIN_TEXT,
u"checksum": u"KHRs/+ZSrc/FuuR4qz/PZQ==",
u"slices": (),
Expand Down Expand Up @@ -126,13 +126,13 @@ def _get_contents_for_upload(info):


def _get_contents(info):
full_path = info[u"path"]
full_path = info.get(u"uncompressed", info[u"path"])
with open(full_path, u"rb") as file_obj:
return file_obj.read()


def _get_blob_name(info):
full_path = info[u"path"]
full_path = info.get(u"uncompressed", info[u"path"])
return os.path.basename(full_path)


Expand Down Expand Up @@ -179,12 +179,6 @@ def check_tombstoned(download, transport):
assert exc_info.match(u"Download has finished.")


def read_raw_content(response):
return b"".join(
response.raw.stream(_helpers._SINGLE_GET_CHUNK_SIZE, decode_content=False)
)


def test_download_full(add_files, authorized_transport):
for info in ALL_FILES:
actual_contents = _get_contents(info)
Expand All @@ -196,7 +190,7 @@ def test_download_full(add_files, authorized_transport):
# Consume the resource.
response = download.consume(authorized_transport)
assert response.status_code == http_client.OK
assert read_raw_content(response) == actual_contents
assert response.content == actual_contents
check_tombstoned(download, authorized_transport)


Expand All @@ -221,6 +215,7 @@ def test_download_to_stream(add_files, authorized_transport):
check_tombstoned(download, authorized_transport)


@pytest.mark.xfail # See: #76
def test_corrupt_download(add_files, corrupting_transport):
for info in ALL_FILES:
blob_name = _get_blob_name(info)
Expand Down Expand Up @@ -396,7 +391,8 @@ def consume_chunks(download, authorized_transport, total_bytes, actual_contents)
return num_responses, response


def test_chunked_download_full(add_files, authorized_transport):
@pytest.mark.xfail # See issue #56
def test_chunked_download(add_files, authorized_transport):
for info in ALL_FILES:
actual_contents = _get_contents(info)
blob_name = _get_blob_name(info)
Expand Down
16 changes: 3 additions & 13 deletions tests/unit/requests/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,12 @@ def test__get_status_code(self):

def test__get_headers(self):
headers = {u"fruit": u"apple"}
response = mock.Mock(headers=headers, spec=["headers"])
response = mock.Mock(headers=headers, spec=[u"headers"])
assert headers == _helpers.RequestsMixin._get_headers(response)

def test__get_body_wo_content_consumed(self):
def test__get_body(self):
body = b"This is the payload."
raw = mock.Mock(spec=["stream"])
raw.stream.return_value = iter([body])
response = mock.Mock(raw=raw, _content=False, spec=["raw", "_content"])
assert body == _helpers.RequestsMixin._get_body(response)
raw.stream.assert_called_once_with(
_helpers._SINGLE_GET_CHUNK_SIZE, decode_content=False
)

def test__get_body_w_content_consumed(self):
body = b"This is the payload."
response = mock.Mock(_content=body, spec=["_content"])
response = mock.Mock(content=body, spec=[u"content"])
assert body == _helpers.RequestsMixin._get_body(response)


Expand Down
Loading