Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

raise IncompleteReadError if only receive partial response #20888

Merged
merged 23 commits into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Features Added

- Added new error type `IncompleteReadError` which is raised if peer closes connection without sending complete message body.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could rephrase as '...if peer closes the connection before we have received the complete message body.'
When I see 'sending', I think 'outgoing data'.

- add kwargs to the methods for `iter_raw` and `iter_bytes` #21529

### Breaking Changes
Expand All @@ -12,6 +13,9 @@

### Bugs Fixed

- The `Content-Length` header in a http response is strictly checked against the actual number of bytes in the body,
rather than silently truncating data in case the underlying tcp connection is closed prematurely.
(thanks to @jochen-ott-by for the contribution) #20412
- UnboundLocalError when SansIOHTTPPolicy handles an exception #15222

### Other Changes
Expand Down
4 changes: 4 additions & 0 deletions sdk/core/azure-core/azure/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ class DecodeError(HttpResponseError):
"""Error raised during response deserialization."""


class IncompleteReadError(DecodeError):
annatisch marked this conversation as resolved.
Show resolved Hide resolved
"""Error raised if peer closes connection without sending complete message body."""


class ResourceExistsError(HttpResponseError):
"""An error response with status code 4xx.
This will not be raised directly by the Azure core pipeline."""
Expand Down
15 changes: 13 additions & 2 deletions sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from multidict import CIMultiDict

from azure.core.configuration import ConnectionConfiguration
from azure.core.exceptions import ServiceRequestError, ServiceResponseError
from azure.core.exceptions import ServiceRequestError, ServiceResponseError, IncompleteReadError
from azure.core.pipeline import Pipeline

from ._base import HttpRequest
Expand Down Expand Up @@ -300,6 +300,12 @@ async def __anext__(self):
except _ResponseStopIteration:
internal_response.close()
raise StopAsyncIteration()
except aiohttp.client_exceptions.ClientPayloadError as err:
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
# This is the case that server closes connection before we finish the reading. aiohttp library
# raises ClientPayloadError.
_LOGGER.warning("Incomplete download: %s", err)
internal_response.close()
raise IncompleteReadError(err, error=err)
except Exception as err:
_LOGGER.warning("Unable to stream download: %s", err)
internal_response.close()
Expand Down Expand Up @@ -384,7 +390,12 @@ def text(self, encoding: Optional[str] = None) -> str:

async def load_body(self) -> None:
"""Load in memory the body, so it could be accessible from sync methods."""
self._content = await self.internal_response.read()
try:
self._content = await self.internal_response.read()
except aiohttp.client_exceptions.ClientPayloadError as err:
# This is the case that server closes connection before we finish the reading. aiohttp library
# raises ClientPayloadError.
raise IncompleteReadError(err, error=err)

def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
"""Generator for streaming response body data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,25 @@
import functools
import logging
from typing import (
Any, Union, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload
Any, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload
)
import urllib3 # type: ignore

import requests

from azure.core.exceptions import (
ServiceRequestError,
ServiceResponseError
ServiceResponseError,
IncompleteReadError,
HttpResponseError,
)
from azure.core.pipeline import Pipeline
from ._base import HttpRequest
from ._base_async import (
AsyncHttpResponse,
_ResponseStopIteration,
_iterate_response_content)
from ._requests_basic import RequestsTransportResponse, _read_raw_stream
from ._requests_basic import RequestsTransportResponse, _read_raw_stream, AzureErrorUnion
from ._base_requests_async import RequestsAsyncTransportBase
from .._tools import is_rest as _is_rest
from .._tools_async import handle_no_stream_rest_response as _handle_no_stream_rest_response
Expand Down Expand Up @@ -134,7 +136,7 @@ async def send(self, request, **kwargs): # pylint:disable=invalid-overridden-me
self.open()
loop = kwargs.get("loop", _get_running_loop())
response = None
error = None # type: Optional[Union[ServiceRequestError, ServiceResponseError]]
error = None # type: Optional[AzureErrorUnion]
data_to_send = await self._retrieve_request_data(request)
try:
response = await loop.run_in_executor(
Expand All @@ -151,6 +153,7 @@ async def send(self, request, **kwargs): # pylint:disable=invalid-overridden-me
cert=kwargs.pop('connection_cert', self.connection_config.cert),
allow_redirects=False,
**kwargs))
response.raw.enforce_content_length = True

except urllib3.exceptions.NewConnectionError as err:
error = ServiceRequestError(err, error=err)
Expand All @@ -161,6 +164,14 @@ async def send(self, request, **kwargs): # pylint:disable=invalid-overridden-me
error = ServiceResponseError(err, error=err)
else:
error = ServiceRequestError(err, error=err)
except requests.exceptions.ChunkedEncodingError as err:
msg = err.__str__()
if 'IncompleteRead' in msg:
_LOGGER.warning("Incomplete download: %s", err)
error = IncompleteReadError(err, error=err)
else:
_LOGGER.warning("Unable to stream download: %s", err)
error = HttpResponseError(err, error=err)
except requests.RequestException as err:
error = ServiceRequestError(err, error=err)

Expand Down Expand Up @@ -223,6 +234,15 @@ async def __anext__(self):
raise StopAsyncIteration()
except requests.exceptions.StreamConsumedError:
raise
except requests.exceptions.ChunkedEncodingError as err:
msg = err.__str__()
if 'IncompleteRead' in msg:
_LOGGER.warning("Incomplete download: %s", err)
internal_response.close()
raise IncompleteReadError(err, error=err)
_LOGGER.warning("Unable to stream download: %s", err)
internal_response.close()
raise HttpResponseError(err, error=err)
except Exception as err:
_LOGGER.warning("Unable to stream download: %s", err)
internal_response.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
from azure.core.configuration import ConnectionConfiguration
from azure.core.exceptions import (
ServiceRequestError,
ServiceResponseError
ServiceResponseError,
IncompleteReadError,
HttpResponseError,
)
from . import HttpRequest # pylint: disable=unused-import

Expand All @@ -51,6 +53,13 @@
if TYPE_CHECKING:
from ...rest import HttpRequest as RestHttpRequest, HttpResponse as RestHttpResponse

AzureErrorUnion = Union[
ServiceRequestError,
ServiceResponseError,
IncompleteReadError,
HttpResponseError,
]

PipelineType = TypeVar("PipelineType")

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -79,6 +88,7 @@ def _read_raw_stream(response, chunk_size=1):
# https://github.com/psf/requests/blob/master/requests/models.py#L774
response._content_consumed = True # pylint: disable=protected-access


class _RequestsTransportResponseBase(_HttpResponseBase):
"""Base class for accessing response data.
Expand Down Expand Up @@ -164,6 +174,15 @@ def __next__(self):
raise StopIteration()
except requests.exceptions.StreamConsumedError:
raise
except requests.exceptions.ChunkedEncodingError as err:
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
msg = err.__str__()
if 'IncompleteRead' in msg:
_LOGGER.warning("Incomplete download: %s", err)
internal_response.close()
raise IncompleteReadError(err, error=err)
_LOGGER.warning("Unable to stream download: %s", err)
internal_response.close()
raise HttpResponseError(err, error=err)
except Exception as err:
_LOGGER.warning("Unable to stream download: %s", err)
internal_response.close()
Expand Down Expand Up @@ -289,7 +308,7 @@ def send(self, request, **kwargs): # type: ignore
"""
self.open()
response = None
error = None # type: Optional[Union[ServiceRequestError, ServiceResponseError]]
error = None # type: Optional[AzureErrorUnion]

try:
connection_timeout = kwargs.pop('connection_timeout', self.connection_config.timeout)
Expand All @@ -313,6 +332,7 @@ def send(self, request, **kwargs): # type: ignore
cert=kwargs.pop('connection_cert', self.connection_config.cert),
allow_redirects=False,
**kwargs)
response.raw.enforce_content_length = True

except (urllib3.exceptions.NewConnectionError, urllib3.exceptions.ConnectTimeoutError) as err:
error = ServiceRequestError(err, error=err)
Expand All @@ -323,6 +343,14 @@ def send(self, request, **kwargs): # type: ignore
error = ServiceResponseError(err, error=err)
else:
error = ServiceRequestError(err, error=err)
except requests.exceptions.ChunkedEncodingError as err:
msg = err.__str__()
if 'IncompleteRead' in msg:
_LOGGER.warning("Incomplete download: %s", err)
error = IncompleteReadError(err, error=err)
else:
_LOGGER.warning("Unable to stream download: %s", err)
error = HttpResponseError(err, error=err)
except requests.RequestException as err:
error = ServiceRequestError(err, error=err)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import functools
import logging
from typing import (
Any, Callable, Union, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload
Any, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload
)
import trio
import urllib3
Expand All @@ -36,15 +36,17 @@

from azure.core.exceptions import (
ServiceRequestError,
ServiceResponseError
ServiceResponseError,
IncompleteReadError,
HttpResponseError,
)
from azure.core.pipeline import Pipeline
from ._base import HttpRequest
from ._base_async import (
AsyncHttpResponse,
_ResponseStopIteration,
_iterate_response_content)
from ._requests_basic import RequestsTransportResponse, _read_raw_stream
from ._requests_basic import RequestsTransportResponse, _read_raw_stream, AzureErrorUnion
from ._base_requests_async import RequestsAsyncTransportBase
from .._tools import is_rest as _is_rest
from .._tools_async import handle_no_stream_rest_response as _handle_no_stream_rest_response
Expand Down Expand Up @@ -105,6 +107,15 @@ async def __anext__(self):
raise StopAsyncIteration()
except requests.exceptions.StreamConsumedError:
raise
except requests.exceptions.ChunkedEncodingError as err:
msg = err.__str__()
if 'IncompleteRead' in msg:
_LOGGER.warning("Incomplete download: %s", err)
internal_response.close()
raise IncompleteReadError(err, error=err)
_LOGGER.warning("Unable to stream download: %s", err)
internal_response.close()
raise HttpResponseError(err, error=err)
except Exception as err:
_LOGGER.warning("Unable to stream download: %s", err)
internal_response.close()
Expand Down Expand Up @@ -184,7 +195,7 @@ async def send(self, request, **kwargs: Any): # pylint:disable=invalid-overridd
self.open()
trio_limiter = kwargs.get("trio_limiter", None)
response = None
error = None # type: Optional[Union[ServiceRequestError, ServiceResponseError]]
error = None # type: Optional[AzureErrorUnion]
data_to_send = await self._retrieve_request_data(request)
try:
try:
Expand Down Expand Up @@ -217,6 +228,7 @@ async def send(self, request, **kwargs: Any): # pylint:disable=invalid-overridd
allow_redirects=False,
**kwargs),
limiter=trio_limiter)
response.raw.enforce_content_length = True

except urllib3.exceptions.NewConnectionError as err:
error = ServiceRequestError(err, error=err)
Expand All @@ -227,6 +239,14 @@ async def send(self, request, **kwargs: Any): # pylint:disable=invalid-overridd
error = ServiceResponseError(err, error=err)
else:
error = ServiceRequestError(err, error=err)
except requests.exceptions.ChunkedEncodingError as err:
msg = err.__str__()
if 'IncompleteRead' in msg:
_LOGGER.warning("Incomplete download: %s", err)
error = IncompleteReadError(err, error=err)
else:
_LOGGER.warning("Unable to stream download: %s", err)
error = HttpResponseError(err, error=err)
except requests.RequestException as err:
error = ServiceRequestError(err, error=err)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# coding: utf-8
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.transport import (
HttpRequest,
)
from azure.core import AsyncPipelineClient
from azure.core.exceptions import IncompleteReadError
import pytest


@pytest.mark.asyncio
async def test_aio_transport_short_read_download_stream(port):
url = "http://localhost:{}/errors/short-data".format(port)
client = AsyncPipelineClient(url)
with pytest.raises(IncompleteReadError):
async with client:
request = HttpRequest("GET", url)
pipeline_response = await client._pipeline.run(request, stream=True)
response = pipeline_response.http_response
data = response.stream_download(client._pipeline)
content = b""
async for d in data:
content += d
27 changes: 27 additions & 0 deletions sdk/core/azure-core/tests/test_content_length_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# coding: utf-8
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from azure.core import PipelineClient
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import (
HttpRequest,
RequestsTransport,
)
from azure.core.exceptions import IncompleteReadError
import pytest


def test_sync_transport_short_read_download_stream(port):
url = "http://localhost:{}/errors/short-data".format(port)
client = PipelineClient(url)
request = HttpRequest("GET", url)
with pytest.raises(IncompleteReadError):
pipeline_response = client._pipeline.run(request, stream=True)
response = pipeline_response.http_response
data = response.stream_download(client._pipeline)
content = b""
for d in data:
content += d
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,10 @@ def __iter__(self):
yield b"Hello, "
yield b"world!"
return Response(StreamingBody(), status=500)

@errors_api.route('/short-data', methods=['GET'])
def get_short_data():
response = Response(b"X" * 4, status=200)
response.automatically_set_content_length = False
response.headers["Content-Length"] = "8"
return response