From cd12344ebf2df447fc4f9a2f562e87242daa0c40 Mon Sep 17 00:00:00 2001 From: iscai-msft Date: Tue, 13 Apr 2021 17:41:04 -0400 Subject: [PATCH] temp --- .../azure-core/azure/core/rest/_rest_py3.py | 54 ++++++++++++------- .../test_rest/test_async_http_response.py | 24 ++++++--- .../tests/test_rest/test_http_response.py | 1 + .../tests/test_rest/test_stream_responses.py | 1 - 4 files changed, 53 insertions(+), 27 deletions(-) diff --git a/sdk/core/azure-core/azure/core/rest/_rest_py3.py b/sdk/core/azure-core/azure/core/rest/_rest_py3.py index 0fe9317b166f..bbf475d7caf8 100644 --- a/sdk/core/azure-core/azure/core/rest/_rest_py3.py +++ b/sdk/core/azure-core/azure/core/rest/_rest_py3.py @@ -545,6 +545,12 @@ def stream_download(self, pipeline=None): :rtype: iterator[bytes] """ + def _validate_streaming_access(self) -> None: + if self.is_closed: + raise TypeError("Can not iterate over stream, it is closed.") + if self.is_stream_consumed: + raise TypeError("Can not iterate over stream, it has been fully consumed") + class HttpResponse(_HttpResponseBase): @property @@ -567,9 +573,16 @@ def read(self) -> bytes: Read the response's bytes. """ - if not hasattr(self, "_content"): - self._content = self._internal_response.internal_response.read() - return self._content + try: + return self._content + except AttributeError: + self._validate_streaming_access() + self._content = ( + self._internal_response.body() or + b"".join(self.iter_raw()) + ) + self._close_stream() + return self._content def iter_bytes(self, chunk_size: int = None) -> Iterator[bytes]: """Iterate over the bytes in the response stream @@ -596,20 +609,20 @@ def iter_lines(self, chunk_size: int = None) -> Iterator[str]: for line in lines: yield line + def _close_stream(self) -> None: + self.is_stream_consumed = True + self.close() + def iter_raw(self, chunk_size: int = None) -> Iterator[bytes]: """Iterate over the raw response bytes """ - if self.is_closed: - raise TypeError("Can not iterate over stream, it is closed.") - if self.is_stream_consumed: - raise TypeError("Can not iterate over stream, it has been fully consumed") + self._validate_streaming_access() stream_download = self._internal_response.stream_download(None, chunk_size=chunk_size) for raw_bytes in stream_download: self._num_bytes_downloaded += len(raw_bytes) yield raw_bytes - self.is_stream_consumed = True - self.close() # close after iterating through everything + self._close_stream() class AsyncHttpResponse(_HttpResponseBase): @@ -621,14 +634,23 @@ def content(self) -> bytes: except AttributeError: raise TypeError("You have not read in the response's bytes yet. Call response.read() first.") + async def _close_stream(self) -> None: + self.is_stream_consumed = True + await self.close() + async def read(self) -> bytes: """ Read the response's bytes. """ - if not hasattr(self, "_content"): - self._content = await self._internal_response.internal_response.read() - return self._content + try: + return self._content + except AttributeError: + self._validate_streaming_access() + await self._internal_response.load_body() + self._content = self._internal_response._body + await self._close_stream() + return self._content async def iter_bytes(self, chunk_size: int = None) -> Iterator[bytes]: """Iterate over the bytes in the response stream @@ -658,17 +680,13 @@ async def iter_lines(self, chunk_size: int = None) -> Iterator[str]: async def iter_raw(self, chunk_size: int = None) -> Iterator[bytes]: """Iterate over the raw response bytes """ - if self.is_closed: - raise TypeError("Can not iterate over stream, it is closed.") - if self.is_stream_consumed: - raise TypeError("Can not iterate over stream, it has been fully consumed") + self._validate_streaming_access() stream_download = self._internal_response.stream_download(None, chunk_size=chunk_size) async for raw_bytes in stream_download: self._num_bytes_downloaded += len(raw_bytes) yield raw_bytes - self.is_stream_consumed = True - await self.close() # close after iterating through everything + await self._close_stream() async def close(self) -> None: self.is_closed = True diff --git a/sdk/core/azure-core/tests/test_rest/test_async_http_response.py b/sdk/core/azure-core/tests/test_rest/test_async_http_response.py index 5f27ec968203..63d7941481ba 100644 --- a/sdk/core/azure-core/tests/test_rest/test_async_http_response.py +++ b/sdk/core/azure-core/tests/test_rest/test_async_http_response.py @@ -48,7 +48,8 @@ async def test_rest_response(): assert response.status_code == 200 assert response.reason == "OK" - await response.read() + content = await response.read() + assert content == b"Hello, world!" assert response.text == "Hello, world!" assert response.request.method == "GET" assert response.request.url == "https://example.org" @@ -61,7 +62,8 @@ async def test_rest_response_content(): assert response.status_code == 200 assert response.reason == "OK" - await response.read() + content = await response.read() + assert content == b"Hello, world!" assert response.text == "Hello, world!" response.raise_for_status() @@ -75,7 +77,8 @@ async def test_rest_response_text(): assert response.status_code == 200 assert response.reason == "OK" - await response.read() + content = await response.read() + assert content == b"Hello, world!" assert response.text == "Hello, world!" assert response.headers == { "Content-Length": "13", @@ -90,7 +93,8 @@ async def test_rest_response_html(): assert response.status_code == 200 assert response.reason == "OK" - await response.read() + content = await response.read() + assert content == b"Hello, world!" assert response.text == "Hello, world!" response.raise_for_status() @@ -131,7 +135,8 @@ async def test_rest_response_content_type_encoding(): content=content, headers=headers, ) - await response.read() + content == await response.read() + assert content == b'Latin 1: \xff' assert response.text == "Latin 1: ÿ" assert response.encoding == "latin-1" @@ -183,7 +188,8 @@ async def test_rest_response_no_charset_with_ascii_content(): ) assert response.status_code == 200 assert response.encoding is None - await response.read() + content = await response.read() + assert content == b"Hello, world!" assert response.text == "Hello, world!" @@ -200,7 +206,8 @@ async def test_rest_response_no_charset_with_iso_8859_1_content(): content=content, headers=headers, ) - await response.read() + content = await response.read() + assert content == b'Accented: \xd6sterreich' assert response.text == "Accented: Österreich" assert response.encoding is None @@ -215,7 +222,8 @@ async def test_rest_response_set_explicit_encoding(): headers=headers, ) response.encoding = "latin-1" - await response.read() + content = await response.read() + assert content == b'Latin 1: \xff' assert response.text == "Latin 1: ÿ" assert response.encoding == "latin-1" diff --git a/sdk/core/azure-core/tests/test_rest/test_http_response.py b/sdk/core/azure-core/tests/test_rest/test_http_response.py index b9b49ea484bd..15c5c3b5733b 100644 --- a/sdk/core/azure-core/tests/test_rest/test_http_response.py +++ b/sdk/core/azure-core/tests/test_rest/test_http_response.py @@ -111,6 +111,7 @@ def test_rest_response_repr(): content=b"Hello, world!", headers=headers ) + response.read() assert repr(response) == "" def test_rest_response_content_type_encoding(): diff --git a/sdk/core/azure-core/tests/test_rest/test_stream_responses.py b/sdk/core/azure-core/tests/test_rest/test_stream_responses.py index 555b73399922..0a8e42c03030 100644 --- a/sdk/core/azure-core/tests/test_rest/test_stream_responses.py +++ b/sdk/core/azure-core/tests/test_rest/test_stream_responses.py @@ -154,7 +154,6 @@ def test_rest_sync_streaming_response(): assert response.content == file_bytes assert response.is_closed - def test_rest_cannot_read_after_stream_consumed(): response = _create_http_response(url="https://httpbin.org/image/jpeg")