From 60e11ba16d40110f9aec29e33751c4f901123594 Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Wed, 3 Nov 2021 08:41:17 -0700 Subject: [PATCH] raise IncompleteReadError if only receive partial response (#20888) * raise IncompleteReadError if only receive partial response * update * Update CHANGELOG.md * update * update * update * update * update * update * update * address review feedback * update * update * update * update * Update exceptions.py --- sdk/core/azure-core/CHANGELOG.md | 4 +++ sdk/core/azure-core/azure/core/exceptions.py | 4 +++ .../azure/core/pipeline/transport/_aiohttp.py | 15 +++++++-- .../pipeline/transport/_requests_asyncio.py | 28 +++++++++++++--- .../pipeline/transport/_requests_basic.py | 32 +++++++++++++++++-- .../core/pipeline/transport/_requests_trio.py | 28 +++++++++++++--- .../test_content_length_checking_async.py | 28 ++++++++++++++++ .../tests/test_content_length_checking.py | 27 ++++++++++++++++ .../coretestserver/test_routes/errors.py | 7 ++++ 9 files changed, 161 insertions(+), 12 deletions(-) create mode 100644 sdk/core/azure-core/tests/async_tests/test_content_length_checking_async.py create mode 100644 sdk/core/azure-core/tests/test_content_length_checking.py diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 6b1bbf9ba9a2..cb381f2d9ad7 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -5,6 +5,7 @@ ### Features Added - add kwargs to the methods for `iter_raw` and `iter_bytes` #21529 +- Added new error type `IncompleteReadError` which is raised if peer closes the connection before we have received the complete message body. ### Breaking Changes @@ -12,6 +13,9 @@ ### Bugs Fixed +- The `Content-Length` header in a http response is strictly checked against the actual number of bytes in the body, + rather than silently truncating data in case the underlying tcp connection is closed prematurely. + (thanks to @jochen-ott-by for the contribution) #20412 - UnboundLocalError when SansIOHTTPPolicy handles an exception #15222 ### Other Changes diff --git a/sdk/core/azure-core/azure/core/exceptions.py b/sdk/core/azure-core/azure/core/exceptions.py index 59d2f4c1ff61..34a21dbab723 100644 --- a/sdk/core/azure-core/azure/core/exceptions.py +++ b/sdk/core/azure-core/azure/core/exceptions.py @@ -338,6 +338,10 @@ class DecodeError(HttpResponseError): """Error raised during response deserialization.""" +class IncompleteReadError(DecodeError): + """Error raised if peer closes the connection before we have received the complete message body.""" + + class ResourceExistsError(HttpResponseError): """An error response with status code 4xx. This will not be raised directly by the Azure core pipeline.""" diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py index 0ade4b7a6f4e..2d18e6bd91c3 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py @@ -36,7 +36,7 @@ from multidict import CIMultiDict from azure.core.configuration import ConnectionConfiguration -from azure.core.exceptions import ServiceRequestError, ServiceResponseError +from azure.core.exceptions import ServiceRequestError, ServiceResponseError, IncompleteReadError from azure.core.pipeline import Pipeline from ._base import HttpRequest @@ -300,6 +300,12 @@ async def __anext__(self): except _ResponseStopIteration: internal_response.close() raise StopAsyncIteration() + except aiohttp.client_exceptions.ClientPayloadError as err: + # This is the case that server closes connection before we finish the reading. aiohttp library + # raises ClientPayloadError. + _LOGGER.warning("Incomplete download: %s", err) + internal_response.close() + raise IncompleteReadError(err, error=err) except Exception as err: _LOGGER.warning("Unable to stream download: %s", err) internal_response.close() @@ -384,7 +390,12 @@ def text(self, encoding: Optional[str] = None) -> str: async def load_body(self) -> None: """Load in memory the body, so it could be accessible from sync methods.""" - self._content = await self.internal_response.read() + try: + self._content = await self.internal_response.read() + except aiohttp.client_exceptions.ClientPayloadError as err: + # This is the case that server closes connection before we finish the reading. aiohttp library + # raises ClientPayloadError. + raise IncompleteReadError(err, error=err) def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: """Generator for streaming response body data. diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py index b5d61aeff474..d0adf93efca0 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py @@ -28,7 +28,7 @@ import functools import logging from typing import ( - Any, Union, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload + Any, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload ) import urllib3 # type: ignore @@ -36,7 +36,9 @@ from azure.core.exceptions import ( ServiceRequestError, - ServiceResponseError + ServiceResponseError, + IncompleteReadError, + HttpResponseError, ) from azure.core.pipeline import Pipeline from ._base import HttpRequest @@ -44,7 +46,7 @@ AsyncHttpResponse, _ResponseStopIteration, _iterate_response_content) -from ._requests_basic import RequestsTransportResponse, _read_raw_stream +from ._requests_basic import RequestsTransportResponse, _read_raw_stream, AzureErrorUnion from ._base_requests_async import RequestsAsyncTransportBase from .._tools import is_rest as _is_rest from .._tools_async import handle_no_stream_rest_response as _handle_no_stream_rest_response @@ -134,7 +136,7 @@ async def send(self, request, **kwargs): # pylint:disable=invalid-overridden-me self.open() loop = kwargs.get("loop", _get_running_loop()) response = None - error = None # type: Optional[Union[ServiceRequestError, ServiceResponseError]] + error = None # type: Optional[AzureErrorUnion] data_to_send = await self._retrieve_request_data(request) try: response = await loop.run_in_executor( @@ -151,6 +153,7 @@ async def send(self, request, **kwargs): # pylint:disable=invalid-overridden-me cert=kwargs.pop('connection_cert', self.connection_config.cert), allow_redirects=False, **kwargs)) + response.raw.enforce_content_length = True except urllib3.exceptions.NewConnectionError as err: error = ServiceRequestError(err, error=err) @@ -161,6 +164,14 @@ async def send(self, request, **kwargs): # pylint:disable=invalid-overridden-me error = ServiceResponseError(err, error=err) else: error = ServiceRequestError(err, error=err) + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if 'IncompleteRead' in msg: + _LOGGER.warning("Incomplete download: %s", err) + error = IncompleteReadError(err, error=err) + else: + _LOGGER.warning("Unable to stream download: %s", err) + error = HttpResponseError(err, error=err) except requests.RequestException as err: error = ServiceRequestError(err, error=err) @@ -223,6 +234,15 @@ async def __anext__(self): raise StopAsyncIteration() except requests.exceptions.StreamConsumedError: raise + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if 'IncompleteRead' in msg: + _LOGGER.warning("Incomplete download: %s", err) + internal_response.close() + raise IncompleteReadError(err, error=err) + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise HttpResponseError(err, error=err) except Exception as err: _LOGGER.warning("Unable to stream download: %s", err) internal_response.close() diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py index 728ae0ad8566..ab9807e7bdc0 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py @@ -36,7 +36,9 @@ from azure.core.configuration import ConnectionConfiguration from azure.core.exceptions import ( ServiceRequestError, - ServiceResponseError + ServiceResponseError, + IncompleteReadError, + HttpResponseError, ) from . import HttpRequest # pylint: disable=unused-import @@ -51,6 +53,13 @@ if TYPE_CHECKING: from ...rest import HttpRequest as RestHttpRequest, HttpResponse as RestHttpResponse +AzureErrorUnion = Union[ + ServiceRequestError, + ServiceResponseError, + IncompleteReadError, + HttpResponseError, +] + PipelineType = TypeVar("PipelineType") _LOGGER = logging.getLogger(__name__) @@ -79,6 +88,7 @@ def _read_raw_stream(response, chunk_size=1): # https://github.com/psf/requests/blob/master/requests/models.py#L774 response._content_consumed = True # pylint: disable=protected-access + class _RequestsTransportResponseBase(_HttpResponseBase): """Base class for accessing response data. @@ -164,6 +174,15 @@ def __next__(self): raise StopIteration() except requests.exceptions.StreamConsumedError: raise + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if 'IncompleteRead' in msg: + _LOGGER.warning("Incomplete download: %s", err) + internal_response.close() + raise IncompleteReadError(err, error=err) + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise HttpResponseError(err, error=err) except Exception as err: _LOGGER.warning("Unable to stream download: %s", err) internal_response.close() @@ -289,7 +308,7 @@ def send(self, request, **kwargs): # type: ignore """ self.open() response = None - error = None # type: Optional[Union[ServiceRequestError, ServiceResponseError]] + error = None # type: Optional[AzureErrorUnion] try: connection_timeout = kwargs.pop('connection_timeout', self.connection_config.timeout) @@ -313,6 +332,7 @@ def send(self, request, **kwargs): # type: ignore cert=kwargs.pop('connection_cert', self.connection_config.cert), allow_redirects=False, **kwargs) + response.raw.enforce_content_length = True except (urllib3.exceptions.NewConnectionError, urllib3.exceptions.ConnectTimeoutError) as err: error = ServiceRequestError(err, error=err) @@ -323,6 +343,14 @@ def send(self, request, **kwargs): # type: ignore error = ServiceResponseError(err, error=err) else: error = ServiceRequestError(err, error=err) + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if 'IncompleteRead' in msg: + _LOGGER.warning("Incomplete download: %s", err) + error = IncompleteReadError(err, error=err) + else: + _LOGGER.warning("Unable to stream download: %s", err) + error = HttpResponseError(err, error=err) except requests.RequestException as err: error = ServiceRequestError(err, error=err) diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py index 5d2b4dfa6285..1fce4048318f 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py @@ -27,7 +27,7 @@ import functools import logging from typing import ( - Any, Callable, Union, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload + Any, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload ) import trio import urllib3 @@ -36,7 +36,9 @@ from azure.core.exceptions import ( ServiceRequestError, - ServiceResponseError + ServiceResponseError, + IncompleteReadError, + HttpResponseError, ) from azure.core.pipeline import Pipeline from ._base import HttpRequest @@ -44,7 +46,7 @@ AsyncHttpResponse, _ResponseStopIteration, _iterate_response_content) -from ._requests_basic import RequestsTransportResponse, _read_raw_stream +from ._requests_basic import RequestsTransportResponse, _read_raw_stream, AzureErrorUnion from ._base_requests_async import RequestsAsyncTransportBase from .._tools import is_rest as _is_rest from .._tools_async import handle_no_stream_rest_response as _handle_no_stream_rest_response @@ -105,6 +107,15 @@ async def __anext__(self): raise StopAsyncIteration() except requests.exceptions.StreamConsumedError: raise + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if 'IncompleteRead' in msg: + _LOGGER.warning("Incomplete download: %s", err) + internal_response.close() + raise IncompleteReadError(err, error=err) + _LOGGER.warning("Unable to stream download: %s", err) + internal_response.close() + raise HttpResponseError(err, error=err) except Exception as err: _LOGGER.warning("Unable to stream download: %s", err) internal_response.close() @@ -184,7 +195,7 @@ async def send(self, request, **kwargs: Any): # pylint:disable=invalid-overridd self.open() trio_limiter = kwargs.get("trio_limiter", None) response = None - error = None # type: Optional[Union[ServiceRequestError, ServiceResponseError]] + error = None # type: Optional[AzureErrorUnion] data_to_send = await self._retrieve_request_data(request) try: try: @@ -217,6 +228,7 @@ async def send(self, request, **kwargs: Any): # pylint:disable=invalid-overridd allow_redirects=False, **kwargs), limiter=trio_limiter) + response.raw.enforce_content_length = True except urllib3.exceptions.NewConnectionError as err: error = ServiceRequestError(err, error=err) @@ -227,6 +239,14 @@ async def send(self, request, **kwargs: Any): # pylint:disable=invalid-overridd error = ServiceResponseError(err, error=err) else: error = ServiceRequestError(err, error=err) + except requests.exceptions.ChunkedEncodingError as err: + msg = err.__str__() + if 'IncompleteRead' in msg: + _LOGGER.warning("Incomplete download: %s", err) + error = IncompleteReadError(err, error=err) + else: + _LOGGER.warning("Unable to stream download: %s", err) + error = HttpResponseError(err, error=err) except requests.RequestException as err: error = ServiceRequestError(err, error=err) diff --git a/sdk/core/azure-core/tests/async_tests/test_content_length_checking_async.py b/sdk/core/azure-core/tests/async_tests/test_content_length_checking_async.py new file mode 100644 index 000000000000..726c4baf32ed --- /dev/null +++ b/sdk/core/azure-core/tests/async_tests/test_content_length_checking_async.py @@ -0,0 +1,28 @@ +# coding: utf-8 +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from azure.core.pipeline import AsyncPipeline +from azure.core.pipeline.transport import ( + HttpRequest, +) +from azure.core import AsyncPipelineClient +from azure.core.exceptions import IncompleteReadError +import pytest + + +@pytest.mark.asyncio +async def test_aio_transport_short_read_download_stream(port): + url = "http://localhost:{}/errors/short-data".format(port) + client = AsyncPipelineClient(url) + with pytest.raises(IncompleteReadError): + async with client: + request = HttpRequest("GET", url) + pipeline_response = await client._pipeline.run(request, stream=True) + response = pipeline_response.http_response + data = response.stream_download(client._pipeline) + content = b"" + async for d in data: + content += d diff --git a/sdk/core/azure-core/tests/test_content_length_checking.py b/sdk/core/azure-core/tests/test_content_length_checking.py new file mode 100644 index 000000000000..7cdbd8f70c2a --- /dev/null +++ b/sdk/core/azure-core/tests/test_content_length_checking.py @@ -0,0 +1,27 @@ +# coding: utf-8 +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from azure.core import PipelineClient +from azure.core.pipeline import Pipeline +from azure.core.pipeline.transport import ( + HttpRequest, + RequestsTransport, +) +from azure.core.exceptions import IncompleteReadError +import pytest + + +def test_sync_transport_short_read_download_stream(port): + url = "http://localhost:{}/errors/short-data".format(port) + client = PipelineClient(url) + request = HttpRequest("GET", url) + with pytest.raises(IncompleteReadError): + pipeline_response = client._pipeline.run(request, stream=True) + response = pipeline_response.http_response + data = response.stream_download(client._pipeline) + content = b"" + for d in data: + content += d diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/errors.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/errors.py index 221f598e063a..bab6ef3c4913 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/errors.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/errors.py @@ -26,3 +26,10 @@ def __iter__(self): yield b"Hello, " yield b"world!" return Response(StreamingBody(), status=500) + +@errors_api.route('/short-data', methods=['GET']) +def get_short_data(): + response = Response(b"X" * 4, status=200) + response.automatically_set_content_length = False + response.headers["Content-Length"] = "8" + return response