Skip to content

Commit

Permalink
raise IncompleteReadError if only receive partial response (#20888)
Browse files Browse the repository at this point in the history
* raise IncompleteReadError if only receive partial response

* update

* Update CHANGELOG.md

* update

* update

* update

* update

* update

* update

* update

* address review feedback

* update

* update

* update

* update

* Update exceptions.py
  • Loading branch information
xiangyan99 authored Nov 3, 2021
1 parent cbb395e commit 60e11ba
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 12 deletions.
4 changes: 4 additions & 0 deletions sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
### Features Added

- add kwargs to the methods for `iter_raw` and `iter_bytes` #21529
- Added new error type `IncompleteReadError` which is raised if peer closes the connection before we have received the complete message body.

### Breaking Changes

- SansIOHTTPPolicy.on_exception returns None instead of bool.

### Bugs Fixed

- The `Content-Length` header in a http response is strictly checked against the actual number of bytes in the body,
rather than silently truncating data in case the underlying tcp connection is closed prematurely.
(thanks to @jochen-ott-by for the contribution) #20412
- UnboundLocalError when SansIOHTTPPolicy handles an exception #15222

### Other Changes
Expand Down
4 changes: 4 additions & 0 deletions sdk/core/azure-core/azure/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ class DecodeError(HttpResponseError):
"""Error raised during response deserialization."""


class IncompleteReadError(DecodeError):
"""Error raised if peer closes the connection before we have received the complete message body."""


class ResourceExistsError(HttpResponseError):
"""An error response with status code 4xx.
This will not be raised directly by the Azure core pipeline."""
Expand Down
15 changes: 13 additions & 2 deletions sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from multidict import CIMultiDict

from azure.core.configuration import ConnectionConfiguration
from azure.core.exceptions import ServiceRequestError, ServiceResponseError
from azure.core.exceptions import ServiceRequestError, ServiceResponseError, IncompleteReadError
from azure.core.pipeline import Pipeline

from ._base import HttpRequest
Expand Down Expand Up @@ -300,6 +300,12 @@ async def __anext__(self):
except _ResponseStopIteration:
internal_response.close()
raise StopAsyncIteration()
except aiohttp.client_exceptions.ClientPayloadError as err:
# This is the case that server closes connection before we finish the reading. aiohttp library
# raises ClientPayloadError.
_LOGGER.warning("Incomplete download: %s", err)
internal_response.close()
raise IncompleteReadError(err, error=err)
except Exception as err:
_LOGGER.warning("Unable to stream download: %s", err)
internal_response.close()
Expand Down Expand Up @@ -384,7 +390,12 @@ def text(self, encoding: Optional[str] = None) -> str:

async def load_body(self) -> None:
"""Load in memory the body, so it could be accessible from sync methods."""
self._content = await self.internal_response.read()
try:
self._content = await self.internal_response.read()
except aiohttp.client_exceptions.ClientPayloadError as err:
# This is the case that server closes connection before we finish the reading. aiohttp library
# raises ClientPayloadError.
raise IncompleteReadError(err, error=err)

def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
"""Generator for streaming response body data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,25 @@
import functools
import logging
from typing import (
Any, Union, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload
Any, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload
)
import urllib3 # type: ignore

import requests

from azure.core.exceptions import (
ServiceRequestError,
ServiceResponseError
ServiceResponseError,
IncompleteReadError,
HttpResponseError,
)
from azure.core.pipeline import Pipeline
from ._base import HttpRequest
from ._base_async import (
AsyncHttpResponse,
_ResponseStopIteration,
_iterate_response_content)
from ._requests_basic import RequestsTransportResponse, _read_raw_stream
from ._requests_basic import RequestsTransportResponse, _read_raw_stream, AzureErrorUnion
from ._base_requests_async import RequestsAsyncTransportBase
from .._tools import is_rest as _is_rest
from .._tools_async import handle_no_stream_rest_response as _handle_no_stream_rest_response
Expand Down Expand Up @@ -134,7 +136,7 @@ async def send(self, request, **kwargs): # pylint:disable=invalid-overridden-me
self.open()
loop = kwargs.get("loop", _get_running_loop())
response = None
error = None # type: Optional[Union[ServiceRequestError, ServiceResponseError]]
error = None # type: Optional[AzureErrorUnion]
data_to_send = await self._retrieve_request_data(request)
try:
response = await loop.run_in_executor(
Expand All @@ -151,6 +153,7 @@ async def send(self, request, **kwargs): # pylint:disable=invalid-overridden-me
cert=kwargs.pop('connection_cert', self.connection_config.cert),
allow_redirects=False,
**kwargs))
response.raw.enforce_content_length = True

except urllib3.exceptions.NewConnectionError as err:
error = ServiceRequestError(err, error=err)
Expand All @@ -161,6 +164,14 @@ async def send(self, request, **kwargs): # pylint:disable=invalid-overridden-me
error = ServiceResponseError(err, error=err)
else:
error = ServiceRequestError(err, error=err)
except requests.exceptions.ChunkedEncodingError as err:
msg = err.__str__()
if 'IncompleteRead' in msg:
_LOGGER.warning("Incomplete download: %s", err)
error = IncompleteReadError(err, error=err)
else:
_LOGGER.warning("Unable to stream download: %s", err)
error = HttpResponseError(err, error=err)
except requests.RequestException as err:
error = ServiceRequestError(err, error=err)

Expand Down Expand Up @@ -223,6 +234,15 @@ async def __anext__(self):
raise StopAsyncIteration()
except requests.exceptions.StreamConsumedError:
raise
except requests.exceptions.ChunkedEncodingError as err:
msg = err.__str__()
if 'IncompleteRead' in msg:
_LOGGER.warning("Incomplete download: %s", err)
internal_response.close()
raise IncompleteReadError(err, error=err)
_LOGGER.warning("Unable to stream download: %s", err)
internal_response.close()
raise HttpResponseError(err, error=err)
except Exception as err:
_LOGGER.warning("Unable to stream download: %s", err)
internal_response.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
from azure.core.configuration import ConnectionConfiguration
from azure.core.exceptions import (
ServiceRequestError,
ServiceResponseError
ServiceResponseError,
IncompleteReadError,
HttpResponseError,
)
from . import HttpRequest # pylint: disable=unused-import

Expand All @@ -51,6 +53,13 @@
if TYPE_CHECKING:
from ...rest import HttpRequest as RestHttpRequest, HttpResponse as RestHttpResponse

AzureErrorUnion = Union[
ServiceRequestError,
ServiceResponseError,
IncompleteReadError,
HttpResponseError,
]

PipelineType = TypeVar("PipelineType")

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -79,6 +88,7 @@ def _read_raw_stream(response, chunk_size=1):
# https://github.com/psf/requests/blob/master/requests/models.py#L774
response._content_consumed = True # pylint: disable=protected-access


class _RequestsTransportResponseBase(_HttpResponseBase):
"""Base class for accessing response data.
Expand Down Expand Up @@ -164,6 +174,15 @@ def __next__(self):
raise StopIteration()
except requests.exceptions.StreamConsumedError:
raise
except requests.exceptions.ChunkedEncodingError as err:
msg = err.__str__()
if 'IncompleteRead' in msg:
_LOGGER.warning("Incomplete download: %s", err)
internal_response.close()
raise IncompleteReadError(err, error=err)
_LOGGER.warning("Unable to stream download: %s", err)
internal_response.close()
raise HttpResponseError(err, error=err)
except Exception as err:
_LOGGER.warning("Unable to stream download: %s", err)
internal_response.close()
Expand Down Expand Up @@ -289,7 +308,7 @@ def send(self, request, **kwargs): # type: ignore
"""
self.open()
response = None
error = None # type: Optional[Union[ServiceRequestError, ServiceResponseError]]
error = None # type: Optional[AzureErrorUnion]

try:
connection_timeout = kwargs.pop('connection_timeout', self.connection_config.timeout)
Expand All @@ -313,6 +332,7 @@ def send(self, request, **kwargs): # type: ignore
cert=kwargs.pop('connection_cert', self.connection_config.cert),
allow_redirects=False,
**kwargs)
response.raw.enforce_content_length = True

except (urllib3.exceptions.NewConnectionError, urllib3.exceptions.ConnectTimeoutError) as err:
error = ServiceRequestError(err, error=err)
Expand All @@ -323,6 +343,14 @@ def send(self, request, **kwargs): # type: ignore
error = ServiceResponseError(err, error=err)
else:
error = ServiceRequestError(err, error=err)
except requests.exceptions.ChunkedEncodingError as err:
msg = err.__str__()
if 'IncompleteRead' in msg:
_LOGGER.warning("Incomplete download: %s", err)
error = IncompleteReadError(err, error=err)
else:
_LOGGER.warning("Unable to stream download: %s", err)
error = HttpResponseError(err, error=err)
except requests.RequestException as err:
error = ServiceRequestError(err, error=err)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import functools
import logging
from typing import (
Any, Callable, Union, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload
Any, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload
)
import trio
import urllib3
Expand All @@ -36,15 +36,17 @@

from azure.core.exceptions import (
ServiceRequestError,
ServiceResponseError
ServiceResponseError,
IncompleteReadError,
HttpResponseError,
)
from azure.core.pipeline import Pipeline
from ._base import HttpRequest
from ._base_async import (
AsyncHttpResponse,
_ResponseStopIteration,
_iterate_response_content)
from ._requests_basic import RequestsTransportResponse, _read_raw_stream
from ._requests_basic import RequestsTransportResponse, _read_raw_stream, AzureErrorUnion
from ._base_requests_async import RequestsAsyncTransportBase
from .._tools import is_rest as _is_rest
from .._tools_async import handle_no_stream_rest_response as _handle_no_stream_rest_response
Expand Down Expand Up @@ -105,6 +107,15 @@ async def __anext__(self):
raise StopAsyncIteration()
except requests.exceptions.StreamConsumedError:
raise
except requests.exceptions.ChunkedEncodingError as err:
msg = err.__str__()
if 'IncompleteRead' in msg:
_LOGGER.warning("Incomplete download: %s", err)
internal_response.close()
raise IncompleteReadError(err, error=err)
_LOGGER.warning("Unable to stream download: %s", err)
internal_response.close()
raise HttpResponseError(err, error=err)
except Exception as err:
_LOGGER.warning("Unable to stream download: %s", err)
internal_response.close()
Expand Down Expand Up @@ -184,7 +195,7 @@ async def send(self, request, **kwargs: Any): # pylint:disable=invalid-overridd
self.open()
trio_limiter = kwargs.get("trio_limiter", None)
response = None
error = None # type: Optional[Union[ServiceRequestError, ServiceResponseError]]
error = None # type: Optional[AzureErrorUnion]
data_to_send = await self._retrieve_request_data(request)
try:
try:
Expand Down Expand Up @@ -217,6 +228,7 @@ async def send(self, request, **kwargs: Any): # pylint:disable=invalid-overridd
allow_redirects=False,
**kwargs),
limiter=trio_limiter)
response.raw.enforce_content_length = True

except urllib3.exceptions.NewConnectionError as err:
error = ServiceRequestError(err, error=err)
Expand All @@ -227,6 +239,14 @@ async def send(self, request, **kwargs: Any): # pylint:disable=invalid-overridd
error = ServiceResponseError(err, error=err)
else:
error = ServiceRequestError(err, error=err)
except requests.exceptions.ChunkedEncodingError as err:
msg = err.__str__()
if 'IncompleteRead' in msg:
_LOGGER.warning("Incomplete download: %s", err)
error = IncompleteReadError(err, error=err)
else:
_LOGGER.warning("Unable to stream download: %s", err)
error = HttpResponseError(err, error=err)
except requests.RequestException as err:
error = ServiceRequestError(err, error=err)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# coding: utf-8
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.transport import (
HttpRequest,
)
from azure.core import AsyncPipelineClient
from azure.core.exceptions import IncompleteReadError
import pytest


@pytest.mark.asyncio
async def test_aio_transport_short_read_download_stream(port):
url = "http://localhost:{}/errors/short-data".format(port)
client = AsyncPipelineClient(url)
with pytest.raises(IncompleteReadError):
async with client:
request = HttpRequest("GET", url)
pipeline_response = await client._pipeline.run(request, stream=True)
response = pipeline_response.http_response
data = response.stream_download(client._pipeline)
content = b""
async for d in data:
content += d
27 changes: 27 additions & 0 deletions sdk/core/azure-core/tests/test_content_length_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# coding: utf-8
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from azure.core import PipelineClient
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import (
HttpRequest,
RequestsTransport,
)
from azure.core.exceptions import IncompleteReadError
import pytest


def test_sync_transport_short_read_download_stream(port):
url = "http://localhost:{}/errors/short-data".format(port)
client = PipelineClient(url)
request = HttpRequest("GET", url)
with pytest.raises(IncompleteReadError):
pipeline_response = client._pipeline.run(request, stream=True)
response = pipeline_response.http_response
data = response.stream_download(client._pipeline)
content = b""
for d in data:
content += d
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,10 @@ def __iter__(self):
yield b"Hello, "
yield b"world!"
return Response(StreamingBody(), status=500)

@errors_api.route('/short-data', methods=['GET'])
def get_short_data():
response = Response(b"X" * 4, status=200)
response.automatically_set_content_length = False
response.headers["Content-Length"] = "8"
return response

0 comments on commit 60e11ba

Please sign in to comment.