diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 8c23c48b7b34..7511f280da26 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -4,7 +4,11 @@ ### Features Added -### Breaking Changes +### Breaking Changes in the Provisional `azure.core.rest` package + +- `azure.core.rest.HttpResponse` and `azure.core.rest.AsyncHttpResponse` are now abstract base classes. They should not be initialized directly, instead +your transport responses should inherit from them and implement them. +- The properties of the `azure.core.rest` responses are now all read-only - HttpLoggingPolicy integrates logs into one record #19925 @@ -24,8 +28,6 @@ - The `text` property on `azure.core.rest.HttpResponse` and `azure.core.rest.AsyncHttpResponse` has changed to a method, which also takes an `encoding` parameter. - Removed `iter_text` and `iter_lines` from `azure.core.rest.HttpResponse` and `azure.core.rest.AsyncHttpResponse` -- `azure.core.rest.HttpResponse` and `azure.core.rest.AsyncHttpResponse` are now abstract base classes. They should not be initialized directly, instead -your transport responses should inherit from them and implement them. ### Bugs Fixed diff --git a/sdk/core/azure-core/azure/core/_pipeline_client.py b/sdk/core/azure-core/azure/core/_pipeline_client.py index f29f391c4ed2..727c206ccb75 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client.py @@ -39,7 +39,6 @@ RequestIdPolicy, RetryPolicy, ) -from .pipeline._tools import to_rest_response as _to_rest_response try: from typing import TYPE_CHECKING @@ -192,22 +191,10 @@ def send_request(self, request, **kwargs): :keyword bool stream: Whether the response payload will be streamed. Defaults to False. :return: The response of your network call. Does not do error handling on your response. :rtype: ~azure.core.rest.HttpResponse - # """ - rest_request = hasattr(request, "content") + """ + stream = kwargs.pop("stream", False) # want to add default value return_pipeline_response = kwargs.pop("_return_pipeline_response", False) - pipeline_response = self._pipeline.run(request, **kwargs) # pylint: disable=protected-access - response = pipeline_response.http_response - if rest_request: - response = _to_rest_response(response) - try: - if not kwargs.get("stream", False): - response.read() - response.close() - except Exception as exc: - response.close() - raise exc + pipeline_response = self._pipeline.run(request, stream=stream, **kwargs) # pylint: disable=protected-access if return_pipeline_response: - pipeline_response.http_response = response - pipeline_response.http_request = request return pipeline_response - return response + return pipeline_response.http_response diff --git a/sdk/core/azure-core/azure/core/_pipeline_client_async.py b/sdk/core/azure-core/azure/core/_pipeline_client_async.py index 9b3674af4703..423d0efa45de 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client_async.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client_async.py @@ -37,7 +37,6 @@ RequestIdPolicy, AsyncRetryPolicy, ) -from .pipeline._tools_async import to_rest_response as _to_rest_response try: from typing import TYPE_CHECKING, TypeVar @@ -194,30 +193,13 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use return AsyncPipeline(transport, policies) async def _make_pipeline_call(self, request, **kwargs): - rest_request = hasattr(request, "content") return_pipeline_response = kwargs.pop("_return_pipeline_response", False) pipeline_response = await self._pipeline.run( request, **kwargs # pylint: disable=protected-access ) - response = pipeline_response.http_response - if rest_request: - rest_response = _to_rest_response(response) - if not kwargs.get("stream"): - try: - # in this case, the pipeline transport response already called .load_body(), so - # the body is loaded. instead of doing response.read(), going to set the body - # to the internal content - rest_response._content = response.body() # pylint: disable=protected-access - await rest_response._set_read_checks() # pylint: disable=protected-access - except Exception as exc: - await rest_response.close() - raise exc - response = rest_response if return_pipeline_response: - pipeline_response.http_response = response - pipeline_response.http_request = request return pipeline_response - return response + return pipeline_response.http_response def send_request( self, diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools.py b/sdk/core/azure-core/azure/core/pipeline/_tools.py index 9e036bf43d98..a9ec65e171e9 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools.py @@ -23,6 +23,11 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + from azure.core.rest import HttpResponse as RestHttpResponse def await_result(func, *args, **kwargs): """If func returns an awaitable, raise that this runner can't handle it.""" @@ -33,38 +38,26 @@ def await_result(func, *args, **kwargs): ) return result -def to_rest_request(pipeline_transport_request): - from ..rest import HttpRequest as RestHttpRequest - return RestHttpRequest( - method=pipeline_transport_request.method, - url=pipeline_transport_request.url, - headers=pipeline_transport_request.headers, - files=pipeline_transport_request.files, - data=pipeline_transport_request.data - ) - -def to_rest_response(pipeline_transport_response): - from .transport._requests_basic import RequestsTransportResponse - from ..rest._requests_basic import RestRequestsTransportResponse - if isinstance(pipeline_transport_response, RequestsTransportResponse): - response_type = RestRequestsTransportResponse - else: - raise ValueError("Unknown transport response") - response = response_type( - request=to_rest_request(pipeline_transport_response.request), - internal_response=pipeline_transport_response.internal_response, - block_size=pipeline_transport_response.block_size - ) - return response +def is_rest(obj): + # type: (Any) -> bool + """Return whether a request or a response is a rest request / response. -def get_block_size(response): - try: - return response._block_size # pylint: disable=protected-access - except AttributeError: - return response.block_size + Checking whether the response has the object content can sometimes result + in a ResponseNotRead error if you're checking the value on a response + that has not been read in yet. To get around this, we also have added + a check for is_stream_consumed, which is an exclusive property on our new responses. + """ + return hasattr(obj, "is_stream_consumed") or hasattr(obj, "content") -def get_internal_response(response): +def handle_non_stream_rest_response(response): + # type: (RestHttpResponse) -> None + """Handle reading and closing of non stream rest responses. + For our new rest responses, we have to call .read() and .close() for our non-stream + responses. This way, we load in the body for users to access. + """ try: - return response._internal_response # pylint: disable=protected-access - except AttributeError: - return response.internal_response + response.read() + response.close() + except Exception as exc: + response.close() + raise exc diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py index 8eaf4a46ec0f..c7defd96c62a 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py @@ -23,7 +23,10 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- -from ._tools import to_rest_request +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ..rest import AsyncHttpResponse as RestAsyncHttpResponse + async def await_result(func, *args, **kwargs): """If func returns an awaitable, await it.""" @@ -33,35 +36,14 @@ async def await_result(func, *args, **kwargs): return await result # type: ignore return result -def _get_response_type(pipeline_transport_response): - try: - from .transport import AioHttpTransportResponse - from ..rest._aiohttp import RestAioHttpTransportResponse - if isinstance(pipeline_transport_response, AioHttpTransportResponse): - return RestAioHttpTransportResponse - except ImportError: - pass +async def handle_no_stream_rest_response(response: "RestAsyncHttpResponse") -> None: + """Handle reading and closing of non stream rest responses. + For our new rest responses, we have to call .read() and .close() for our non-stream + responses. This way, we load in the body for users to access. + """ try: - from .transport import AsyncioRequestsTransportResponse - from ..rest._requests_asyncio import RestAsyncioRequestsTransportResponse - if isinstance(pipeline_transport_response, AsyncioRequestsTransportResponse): - return RestAsyncioRequestsTransportResponse - except ImportError: - pass - try: - from .transport import TrioRequestsTransportResponse - from ..rest._requests_trio import RestTrioRequestsTransportResponse - if isinstance(pipeline_transport_response, TrioRequestsTransportResponse): - return RestTrioRequestsTransportResponse - except ImportError: - pass - raise ValueError("Unknown transport response") - -def to_rest_response(pipeline_transport_response): - response_type = _get_response_type(pipeline_transport_response) - response = response_type( - request=to_rest_request(pipeline_transport_response.request), - internal_response=pipeline_transport_response.internal_response, - block_size=pipeline_transport_response.block_size, - ) - return response + await response.read() + await response.close() + except Exception as exc: + await response.close() + raise exc diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py b/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py index 8f18d04883bb..178db95a7dd7 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py @@ -617,6 +617,11 @@ def deserialize_from_http_generics( mime_type = "application/json" # Rely on transport implementation to give me "text()" decoded correctly + if hasattr(response, "read"): + # since users can call deserialize_from_http_generics by themselves + # we want to make sure our new responses are read before we try to + # deserialize + response.read() return cls.deserialize_from_text(response.text(encoding), mime_type, response=response) def on_request(self, request): diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py index c5cd6816f2f5..dc03b29639f1 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py @@ -24,7 +24,9 @@ # # -------------------------------------------------------------------------- import sys -from typing import Any, Optional, AsyncIterator as AsyncIteratorType +from typing import ( + Any, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload +) from collections.abc import AsyncIterator try: import cchardet as chardet @@ -46,7 +48,14 @@ AsyncHttpTransport, AsyncHttpResponse, _ResponseStopIteration) -from .._tools import get_block_size as _get_block_size, get_internal_response as _get_internal_response +from ...utils._pipeline_transport_rest_shared import _aiohttp_body_helper +from .._tools import is_rest as _is_rest +from .._tools_async import handle_no_stream_rest_response as _handle_no_stream_rest_response +if TYPE_CHECKING: + from ...rest import ( + HttpRequest as RestHttpRequest, + AsyncHttpResponse as RestAsyncHttpResponse, + ) # Matching requests, because why not? CONTENT_CHUNK_SIZE = 10 * 1024 @@ -135,6 +144,7 @@ def _get_request_data(self, request): #pylint: disable=no-self-use return form_data return request.data + @overload async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpResponse]: """Send the request using this HTTP sender. @@ -151,6 +161,41 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR :keyword dict proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url) :keyword str proxy: will define the proxy to use all the time """ + + @overload + async def send(self, request: "RestHttpRequest", **config: Any) -> Optional["RestAsyncHttpResponse"]: + """Send the `azure.core.rest` request using this HTTP sender. + + Will pre-load the body into memory to be available with a sync method. + Pass stream=True to avoid this behavior. + + :param request: The HttpRequest object + :type request: ~azure.core.rest.HttpRequest + :param config: Any keyword arguments + :return: The AsyncHttpResponse + :rtype: ~azure.core.rest.AsyncHttpResponse + + :keyword bool stream: Defaults to False. + :keyword dict proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url) + :keyword str proxy: will define the proxy to use all the time + """ + + async def send(self, request, **config): + """Send the request using this HTTP sender. + + Will pre-load the body into memory to be available with a sync method. + Pass stream=True to avoid this behavior. + + :param request: The HttpRequest object + :type request: ~azure.core.pipeline.transport.HttpRequest + :param config: Any keyword arguments + :return: The AsyncHttpResponse + :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse + + :keyword bool stream: Defaults to False. + :keyword dict proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url) + :keyword str proxy: will define the proxy to use all the time + """ await self.open() try: auto_decompress = self.session.auto_decompress # type: ignore @@ -168,7 +213,7 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR config['proxy'] = proxies[protocol] break - response = None + response: Optional["HTTPResponseType"] = None config['ssl'] = self._build_ssl_config( cert=config.pop('connection_cert', self.connection_config.cert), verify=config.pop('connection_verify', self.connection_config.verify) @@ -192,11 +237,22 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR allow_redirects=False, **config ) - response = AioHttpTransportResponse(request, result, - self.connection_config.data_block_size, - decompress=not auto_decompress) - if not stream_response: - await response.load_body() + if _is_rest(request): + from azure.core.rest._aiohttp import RestAioHttpTransportResponse + response = RestAioHttpTransportResponse( + request=request, + internal_response=result, + block_size=self.connection_config.data_block_size, + decompress=not auto_decompress + ) + if not stream_response: + await _handle_no_stream_rest_response(response) + else: + response = AioHttpTransportResponse(request, result, + self.connection_config.data_block_size, + decompress=not auto_decompress) + if not stream_response: + await response.load_body() except aiohttp.client_exceptions.ClientResponseError as err: raise ServiceResponseError(err, error=err) from err except aiohttp.client_exceptions.ClientError as err: @@ -217,9 +273,9 @@ def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, *, decompres self.pipeline = pipeline self.request = response.request self.response = response - self.block_size = _get_block_size(response) + self.block_size = response.block_size self._decompress = decompress - internal_response = _get_internal_response(response) + internal_response = response.internal_response self.content_length = int(internal_response.headers.get('Content-Length', 0)) self._decompressor = None @@ -227,7 +283,7 @@ def __len__(self): return self.content_length async def __anext__(self): - internal_response = _get_internal_response(self.response) + internal_response = self.response.internal_response try: chunk = await internal_response.content.read(self.block_size) if not chunk: @@ -274,30 +330,14 @@ def __init__(self, request: HttpRequest, self.headers = CIMultiDict(aiohttp_response.headers) self.reason = aiohttp_response.reason self.content_type = aiohttp_response.headers.get('content-type') - self._body = None - self._decompressed_body = None + self._content = None + self._decompressed_content = None self._decompress = decompress def body(self) -> bytes: """Return the whole body as bytes in memory. """ - if self._body is None: - raise ValueError("Body is not available. Call async method load_body, or do your call with stream=False.") - if not self._decompress: - return self._body - enc = self.headers.get('Content-Encoding') - if not enc: - return self._body - enc = enc.lower() - if enc in ("gzip", "deflate"): - if self._decompressed_body: - return self._decompressed_body - import zlib - zlib_mode = 16 + zlib.MAX_WBITS if enc == "gzip" else zlib.MAX_WBITS - decompressor = zlib.decompressobj(wbits=zlib_mode) - self._decompressed_body = decompressor.decompress(self._body) - return self._decompressed_body - return self._body + return _aiohttp_body_helper(self) def text(self, encoding: Optional[str] = None) -> str: """Return the whole body as a string. @@ -306,7 +346,7 @@ def text(self, encoding: Optional[str] = None) -> str: :param str encoding: The encoding to apply. """ - # super().text detects charset based on self._body() which is compressed + # super().text detects charset based on self._content() which is compressed # implement the decoding explicitly here body = self.body() @@ -339,7 +379,7 @@ def text(self, encoding: Optional[str] = None) -> str: async def load_body(self) -> None: """Load in memory the body, so it could be accessible from sync methods.""" - self._body = await self.internal_response.read() + self._content = await self.internal_response.read() def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: """Generator for streaming response body data. diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py index 7c420a392059..75c72ddc5e99 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py @@ -26,12 +26,6 @@ from __future__ import absolute_import import abc from email.message import Message - -try: - from email import message_from_bytes as message_parser -except ImportError: # 2.7 - from email import message_from_string as message_parser # type: ignore -from io import BytesIO import json import logging import time @@ -67,22 +61,21 @@ from azure.core.pipeline import ( ABC, AbstractContextManager, - PipelineRequest, - PipelineResponse, - PipelineContext, ) -from .._tools import await_result as _await_result from ...utils._utils import _case_insensitive_dict from ...utils._pipeline_transport_rest_shared import ( _format_parameters_helper, _prepare_multipart_body_helper, _serialize_request, _format_data_helper, + BytesIOSocket, + _decode_parts_helper, + _get_raw_parts_helper, + _parts_helper, ) if TYPE_CHECKING: - from ..policies import SansIOHTTPPolicy from collections.abc import MutableMapping HTTPResponseType = TypeVar("HTTPResponseType") @@ -191,7 +184,9 @@ def __deepcopy__(self, memo=None): try: data = copy.deepcopy(self.body, memo) files = copy.deepcopy(self.files, memo) - return HttpRequest(self.method, self.url, self.headers, files, data) + request = HttpRequest(self.method, self.url, self.headers, files, data) + request.multipart_mixed_info = self.multipart_mixed_info + return request except (ValueError, TypeError): return copy.copy(self) @@ -415,27 +410,7 @@ def text(self, encoding=None): def _decode_parts(self, message, http_response_type, requests): # type: (Message, Type[_HttpResponseBase], List[HttpRequest]) -> List[HttpResponse] """Rebuild an HTTP response from pure string.""" - responses = [] - for index, raw_reponse in enumerate(message.get_payload()): - content_type = raw_reponse.get_content_type() - if content_type == "application/http": - responses.append( - _deserialize_response( - raw_reponse.get_payload(decode=True), - requests[index], - http_response_type=http_response_type, - ) - ) - elif content_type == "multipart/mixed" and requests[index].multipart_mixed_info: - # The message batch contains one or more change sets - changeset_requests = requests[index].multipart_mixed_info[0] # type: ignore - changeset_responses = self._decode_parts(raw_reponse, http_response_type, changeset_requests) - responses.extend(changeset_responses) - else: - raise ValueError( - "Multipart doesn't support part other than application/http for now" - ) - return responses + return _decode_parts_helper(self, message, http_response_type, requests, _deserialize_response) def _get_raw_parts(self, http_response_type=None): # type (Optional[Type[_HttpResponseBase]]) -> Iterator[HttpResponse] @@ -444,20 +419,7 @@ def _get_raw_parts(self, http_response_type=None): If parts are application/http use http_response_type or HttpClientTransportResponse as enveloppe. """ - if http_response_type is None: - http_response_type = HttpClientTransportResponse - - body_as_bytes = self.body() - # In order to use email.message parser, I need full HTTP bytes. Faking something to make the parser happy - http_body = ( - b"Content-Type: " - + self.content_type.encode("ascii") - + b"\r\n\r\n" - + body_as_bytes - ) - message = message_parser(http_body) # type: Message - requests = self.request.multipart_mixed_info[0] # type: List[HttpRequest] - return self._decode_parts(message, http_response_type, requests) + return _get_raw_parts_helper(self, http_response_type or HttpClientTransportResponse) def raise_for_status(self): # type () -> None @@ -495,36 +457,7 @@ def parts(self): :rtype: iterator[HttpResponse] :raises ValueError: If the content is not multipart/mixed """ - if not self.content_type or not self.content_type.startswith("multipart/mixed"): - raise ValueError( - "You can't get parts if the response is not multipart/mixed" - ) - - responses = self._get_raw_parts() - if self.request.multipart_mixed_info: - policies = self.request.multipart_mixed_info[1] # type: List[SansIOHTTPPolicy] - - # Apply on_response concurrently to all requests - import concurrent.futures - - def parse_responses(response): - http_request = response.request - context = PipelineContext(None) - pipeline_request = PipelineRequest(http_request, context) - pipeline_response = PipelineResponse( - http_request, response, context=context - ) - - for policy in policies: - _await_result(policy.on_response, pipeline_request, pipeline_response) - - with concurrent.futures.ThreadPoolExecutor() as executor: - # List comprehension to raise exceptions if happened - [ # pylint: disable=expression-not-assigned, unnecessary-comprehension - _ for _ in executor.map(parse_responses, responses) - ] - - return responses + return _parts_helper(self) class _HttpClientTransportResponse(_HttpResponseBase): @@ -557,20 +490,6 @@ class HttpClientTransportResponse(_HttpClientTransportResponse, HttpResponse): """ -class BytesIOSocket(object): - """Mocking the "makefile" of socket for HTTPResponse. - - This can be used to create a http.client.HTTPResponse object - based on bytes and not a real socket. - """ - - def __init__(self, bytes_data): - self.bytes_data = bytes_data - - def makefile(self, *_): - return BytesIO(self.bytes_data) - - def _deserialize_response( http_response_as_bytes, http_request, http_response_type=HttpClientTransportResponse ): diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py index 73fcd51bf957..30fe0d81b2a1 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py @@ -32,11 +32,8 @@ from ._base import ( _HttpResponseBase, _HttpClientTransportResponse, - PipelineContext, - PipelineRequest, - PipelineResponse, ) -from .._tools_async import await_result as _await_result +from ...utils._pipeline_transport_rest_shared_async import _PartGenerator try: from contextlib import AbstractAsyncContextManager # type: ignore @@ -70,54 +67,6 @@ def _iterate_response_content(iterator): raise _ResponseStopIteration() -class _PartGenerator(AsyncIterator): - """Until parts is a real async iterator, wrap the sync call. - - :param parts: An iterable of parts - """ - - def __init__(self, response: "AsyncHttpResponse") -> None: - self._response = response - self._parts = None - - async def _parse_response(self): - responses = self._response._get_raw_parts( # pylint: disable=protected-access - http_response_type=AsyncHttpClientTransportResponse - ) - if self._response.request.multipart_mixed_info: - policies = self._response.request.multipart_mixed_info[ - 1 - ] # type: List[SansIOHTTPPolicy] - - async def parse_responses(response): - http_request = response.request - context = PipelineContext(None) - pipeline_request = PipelineRequest(http_request, context) - pipeline_response = PipelineResponse( - http_request, response, context=context - ) - - for policy in policies: - await _await_result( - policy.on_response, pipeline_request, pipeline_response - ) - - # Not happy to make this code asyncio specific, but that's multipart only for now - # If we need trio and multipart, let's reinvesitgate that later - await asyncio.gather(*[parse_responses(res) for res in responses]) - - return responses - - async def __anext__(self): - if not self._parts: - self._parts = iter(await self._parse_response()) - - try: - return next(self._parts) - except StopIteration: - raise StopAsyncIteration() - - class AsyncHttpResponse(_HttpResponseBase): # pylint: disable=abstract-method """An AsyncHttpResponse ABC. @@ -147,7 +96,7 @@ def parts(self) -> AsyncIterator: "You can't get parts if the response is not multipart/mixed" ) - return _PartGenerator(self) + return _PartGenerator(self, default_http_response_type=AsyncHttpClientTransportResponse) class AsyncHttpClientTransportResponse(_HttpClientTransportResponse, AsyncHttpResponse): diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py index e41e4de91325..b5d61aeff474 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py @@ -27,7 +27,9 @@ from collections.abc import AsyncIterator import functools import logging -from typing import Any, Union, Optional, AsyncIterator as AsyncIteratorType +from typing import ( + Any, Union, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload +) import urllib3 # type: ignore import requests @@ -44,8 +46,14 @@ _iterate_response_content) from ._requests_basic import RequestsTransportResponse, _read_raw_stream from ._base_requests_async import RequestsAsyncTransportBase -from .._tools import get_block_size as _get_block_size, get_internal_response as _get_internal_response +from .._tools import is_rest as _is_rest +from .._tools_async import handle_no_stream_rest_response as _handle_no_stream_rest_response +if TYPE_CHECKING: + from ...rest import ( + HttpRequest as RestHttpRequest, + AsyncHttpResponse as RestAsyncHttpResponse + ) _LOGGER = logging.getLogger(__name__) @@ -83,7 +91,35 @@ async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ async def sleep(self, duration): # pylint:disable=invalid-overridden-method await asyncio.sleep(duration) - async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse: # type: ignore # pylint:disable=invalid-overridden-method + @overload # type: ignore + async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse: # pylint:disable=invalid-overridden-method + """Send the request using this HTTP sender. + + :param request: The HttpRequest + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse + + :keyword requests.Session session: will override the driver session and use yours. + Should NOT be done unless really required. Anything else is sent straight to requests. + :keyword dict proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + + @overload # type: ignore + async def send(self, request: "RestHttpRequest", **kwargs: Any) -> "RestAsyncHttpResponse": # pylint:disable=invalid-overridden-method + """Send a `azure.core.rest` request using this HTTP sender. + + :param request: The HttpRequest + :type request: ~azure.core.rest.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.rest.AsyncHttpResponse + + :keyword requests.Session session: will override the driver session and use yours. + Should NOT be done unless really required. Anything else is sent straight to requests. + :keyword dict proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + + async def send(self, request, **kwargs): # pylint:disable=invalid-overridden-method """Send the request using this HTTP sender. :param request: The HttpRequest @@ -130,6 +166,16 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse: if error: raise error + if _is_rest(request): + from azure.core.rest._requests_asyncio import RestAsyncioRequestsTransportResponse + retval = RestAsyncioRequestsTransportResponse( + request=request, + internal_response=response, + block_size=self.connection_config.data_block_size + ) + if not kwargs.get("stream"): + await _handle_no_stream_rest_response(retval) + return retval return AsyncioRequestsTransportResponse(request, response, self.connection_config.data_block_size) @@ -146,11 +192,11 @@ def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> self.pipeline = pipeline self.request = response.request self.response = response - self.block_size = _get_block_size(response) + self.block_size = response.block_size decompress = kwargs.pop("decompress", True) if len(kwargs) > 0: raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) - internal_response = _get_internal_response(response) + internal_response = response.internal_response if decompress: self.iter_content_func = internal_response.iter_content(self.block_size) else: @@ -162,7 +208,7 @@ def __len__(self): async def __anext__(self): loop = _get_running_loop() - internal_response = _get_internal_response(self.response) + internal_response = self.response.internal_response try: chunk = await loop.run_in_executor( None, diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py index 28b81d705c16..728ae0ad8566 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py @@ -25,7 +25,7 @@ # -------------------------------------------------------------------------- from __future__ import absolute_import import logging -from typing import Iterator, Optional, Any, Union, TypeVar +from typing import Iterator, Optional, Any, Union, TypeVar, overload, TYPE_CHECKING import urllib3 # type: ignore from urllib3.util.retry import Retry # type: ignore from urllib3.exceptions import ( @@ -46,7 +46,10 @@ _HttpResponseBase ) from ._bigger_block_size_http_adapters import BiggerBlockSizeHTTPAdapter -from .._tools import get_block_size as _get_block_size, get_internal_response as _get_internal_response +from .._tools import is_rest as _is_rest, handle_non_stream_rest_response as _handle_non_stream_rest_response + +if TYPE_CHECKING: + from ...rest import HttpRequest as RestHttpRequest, HttpResponse as RestHttpResponse PipelineType = TypeVar("PipelineType") @@ -132,11 +135,11 @@ def __init__(self, pipeline, response, **kwargs): self.pipeline = pipeline self.request = response.request self.response = response - self.block_size = _get_block_size(response) + self.block_size = response.block_size decompress = kwargs.pop("decompress", True) if len(kwargs) > 0: raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) - internal_response = _get_internal_response(response) + internal_response = response.internal_response if decompress: self.iter_content_func = internal_response.iter_content(self.block_size) else: @@ -150,7 +153,7 @@ def __iter__(self): return self def __next__(self): - internal_response = _get_internal_response(self.response) + internal_response = self.response.internal_response try: chunk = next(self.iter_content_func) if not chunk: @@ -242,8 +245,37 @@ def close(self): self._session_owner = False self.session = None - def send(self, request, **kwargs): # type: ignore + @overload + def send(self, request, **kwargs): # type: (HttpRequest, Any) -> HttpResponse + """Send a rest request and get back a rest response. + + :param request: The request object to be sent. + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: An HTTPResponse object. + :rtype: ~azure.core.pipeline.transport.HttpResponse + + :keyword requests.Session session: will override the driver session and use yours. + Should NOT be done unless really required. Anything else is sent straight to requests. + :keyword dict proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + + @overload + def send(self, request, **kwargs): + # type: (RestHttpRequest, Any) -> RestHttpResponse + """Send an `azure.core.rest` request and get back a rest response. + + :param request: The request object to be sent. + :type request: ~azure.core.rest.HttpRequest + :return: An HTTPResponse object. + :rtype: ~azure.core.rest.HttpResponse + + :keyword requests.Session session: will override the driver session and use yours. + Should NOT be done unless really required. Anything else is sent straight to requests. + :keyword dict proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + + def send(self, request, **kwargs): # type: ignore """Send request object according to configuration. :param request: The request object to be sent. @@ -296,4 +328,14 @@ def send(self, request, **kwargs): # type: ignore if error: raise error + if _is_rest(request): + from azure.core.rest._requests_basic import RestRequestsTransportResponse + retval = RestRequestsTransportResponse( + request=request, + internal_response=response, + block_size=self.connection_config.data_block_size + ) + if not kwargs.get('stream'): + _handle_non_stream_rest_response(retval) + return retval return RequestsTransportResponse(request, response, self.connection_config.data_block_size) diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py index e21ee5115327..5d2b4dfa6285 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py @@ -26,7 +26,9 @@ from collections.abc import AsyncIterator import functools import logging -from typing import Any, Callable, Union, Optional, AsyncIterator as AsyncIteratorType +from typing import ( + Any, Callable, Union, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload +) import trio import urllib3 @@ -44,7 +46,13 @@ _iterate_response_content) from ._requests_basic import RequestsTransportResponse, _read_raw_stream from ._base_requests_async import RequestsAsyncTransportBase -from .._tools import get_block_size as _get_block_size, get_internal_response as _get_internal_response +from .._tools import is_rest as _is_rest +from .._tools_async import handle_no_stream_rest_response as _handle_no_stream_rest_response +if TYPE_CHECKING: + from ...rest import ( + HttpRequest as RestHttpRequest, + AsyncHttpResponse as RestAsyncHttpResponse, + ) _LOGGER = logging.getLogger(__name__) @@ -62,11 +70,11 @@ def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> self.pipeline = pipeline self.request = response.request self.response = response - self.block_size = _get_block_size(response) + self.block_size = response.block_size decompress = kwargs.pop("decompress", True) if len(kwargs) > 0: raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) - internal_response = _get_internal_response(response) + internal_response = response.internal_response if decompress: self.iter_content_func = internal_response.iter_content(self.block_size) else: @@ -77,7 +85,7 @@ def __len__(self): return self.content_length async def __anext__(self): - internal_response = _get_internal_response(self.response) + internal_response = self.response.internal_response try: try: chunk = await trio.to_thread.run_sync( @@ -133,7 +141,35 @@ async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ async def sleep(self, duration): # pylint:disable=invalid-overridden-method await trio.sleep(duration) - async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse: # type: ignore # pylint:disable=invalid-overridden-method + @overload # type: ignore + async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse: # pylint:disable=invalid-overridden-method + """Send the request using this HTTP sender. + + :param request: The HttpRequest + :type request: ~azure.core.pipeline.transport.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse + + :keyword requests.Session session: will override the driver session and use yours. + Should NOT be done unless really required. Anything else is sent straight to requests. + :keyword dict proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + + @overload # type: ignore + async def send(self, request: "RestHttpRequest", **kwargs: Any) -> "RestAsyncHttpResponse": # pylint:disable=invalid-overridden-method + """Send an `azure.core.rest` request using this HTTP sender. + + :param request: The HttpRequest + :type request: ~azure.core.rest.HttpRequest + :return: The AsyncHttpResponse + :rtype: ~azure.core.rest.AsyncHttpResponse + + :keyword requests.Session session: will override the driver session and use yours. + Should NOT be done unless really required. Anything else is sent straight to requests. + :keyword dict proxies: will define the proxy to use. Proxy is a dict (protocol, url) + """ + + async def send(self, request, **kwargs: Any): # pylint:disable=invalid-overridden-method """Send the request using this HTTP sender. :param request: The HttpRequest @@ -196,5 +232,15 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse: if error: raise error + if _is_rest(request): + from azure.core.rest._requests_trio import RestTrioRequestsTransportResponse + retval = RestTrioRequestsTransportResponse( + request=request, + internal_response=response, + block_size=self.connection_config.data_block_size, + ) + if not kwargs.get("stream"): + await _handle_no_stream_rest_response(retval) + return retval return TrioRequestsTransportResponse(request, response, self.connection_config.data_block_size) diff --git a/sdk/core/azure-core/azure/core/polling/async_base_polling.py b/sdk/core/azure-core/azure/core/polling/async_base_polling.py index 104c0691af11..f8c64cd00735 100644 --- a/sdk/core/azure-core/azure/core/polling/async_base_polling.py +++ b/sdk/core/azure-core/azure/core/polling/async_base_polling.py @@ -32,6 +32,7 @@ LROBasePolling, _raise_if_bad_http_status_and_method, ) +from ..pipeline._tools import is_rest __all__ = ["AsyncLROBasePolling"] @@ -119,7 +120,7 @@ async def request_status(self, status_link): # pylint:disable=invalid-overridde # Re-inject 'x-ms-client-request-id' while polling if "request_id" not in self._operation_config: self._operation_config["request_id"] = self._get_request_id() - if hasattr(self._initial_response.http_response, "content"): + if is_rest(self._initial_response.http_response): # if I am a azure.core.rest.HttpResponse # want to keep making azure.core.rest calls from azure.core.rest import HttpRequest as RestHttpRequest diff --git a/sdk/core/azure-core/azure/core/polling/base_polling.py b/sdk/core/azure-core/azure/core/polling/base_polling.py index e82da1d230c3..23f3a56b196c 100644 --- a/sdk/core/azure-core/azure/core/polling/base_polling.py +++ b/sdk/core/azure-core/azure/core/polling/base_polling.py @@ -31,6 +31,7 @@ from ..exceptions import HttpResponseError, DecodeError from . import PollingMethod from ..pipeline.policies._utils import get_retry_after +from ..pipeline._tools import is_rest if TYPE_CHECKING: from azure.core.pipeline import PipelineResponse @@ -121,8 +122,7 @@ def _is_empty(response): :rtype: bool """ - content = response.content if hasattr(response, "content") else response.body() # type: ignore - return not bool(content) + return not bool(response.body()) class LongRunningOperation(ABC): @@ -578,7 +578,7 @@ def request_status(self, status_link): # Re-inject 'x-ms-client-request-id' while polling if "request_id" not in self._operation_config: self._operation_config["request_id"] = self._get_request_id() - if hasattr(self._initial_response.http_response, "content"): + if is_rest(self._initial_response.http_response): # if I am a azure.core.rest.HttpResponse # want to keep making azure.core.rest calls from azure.core.rest import HttpRequest as RestHttpRequest diff --git a/sdk/core/azure-core/azure/core/rest/_aiohttp.py b/sdk/core/azure-core/azure/core/rest/_aiohttp.py index 1cd830f069c7..67ea2346747b 100644 --- a/sdk/core/azure-core/azure/core/rest/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/rest/_aiohttp.py @@ -28,8 +28,9 @@ from itertools import groupby from typing import AsyncIterator from multidict import CIMultiDict -from ._http_response_impl_async import AsyncHttpResponseImpl +from ._http_response_impl_async import AsyncHttpResponseImpl, AsyncHttpResponseBackcompatMixin from ..pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator +from ..utils._pipeline_transport_rest_shared import _pad_attr_name, _aiohttp_body_helper class _ItemsView(collections.abc.ItemsView): def __init__(self, ref): @@ -114,7 +115,34 @@ def get(self, key, default=None): values = ", ".join(values) return values or default -class RestAioHttpTransportResponse(AsyncHttpResponseImpl): +class _RestAioHttpTransportResponseBackcompatMixin(AsyncHttpResponseBackcompatMixin): + """Backcompat mixin for aiohttp responses. + + Need to add it's own mixin because it has function load_body, which other + transport responses don't have, and also because we need to synchronously + decompress the body if users call .body() + """ + + def body(self) -> bytes: + """Return the whole body as bytes in memory. + + Have to modify the default behavior here. In AioHttp, we do decompression + when accessing the body method. The behavior here is the same as if the + caller did an async read of the response first. But for backcompat reasons, + we need to support this decompression within the synchronous body method. + """ + return _aiohttp_body_helper(self) + + async def _load_body(self) -> None: + """Load in memory the body, so it could be accessible from sync methods.""" + self._content = await self.read() # type: ignore + + def __getattr__(self, attr): + backcompat_attrs = ["load_body"] + attr = _pad_attr_name(attr, backcompat_attrs) + return super().__getattr__(attr) + +class RestAioHttpTransportResponse(AsyncHttpResponseImpl, _RestAioHttpTransportResponseBackcompatMixin): def __init__( self, *, @@ -134,6 +162,7 @@ def __init__( **kwargs ) self._decompress = decompress + self._decompressed_content = None def __getstate__(self): state = self.__dict__.copy() diff --git a/sdk/core/azure-core/azure/core/rest/_helpers.py b/sdk/core/azure-core/azure/core/rest/_helpers.py index 613ab723120a..58b544cc4f15 100644 --- a/sdk/core/azure-core/azure/core/rest/_helpers.py +++ b/sdk/core/azure-core/azure/core/rest/_helpers.py @@ -23,6 +23,7 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- +import copy import codecs import cgi from json import dumps @@ -370,3 +371,7 @@ def _serialize(self): :rtype: bytes """ return _serialize_request(self) + + def _add_backcompat_properties(self, request, memo): + """While deepcopying, we also need to add the private backcompat attrs""" + request._multipart_mixed_info = copy.deepcopy(self._multipart_mixed_info, memo) # pylint: disable=protected-access diff --git a/sdk/core/azure-core/azure/core/rest/_http_response_impl.py b/sdk/core/azure-core/azure/core/rest/_http_response_impl.py index 67a57a676e08..44badd1eb9d6 100644 --- a/sdk/core/azure-core/azure/core/rest/_http_response_impl.py +++ b/sdk/core/azure-core/azure/core/rest/_http_response_impl.py @@ -25,6 +25,7 @@ # -------------------------------------------------------------------------- from json import loads from typing import cast, TYPE_CHECKING +from six.moves.http_client import HTTPResponse as _HTTPResponse from ._helpers import ( get_charset_encoding, decode_to_text, @@ -42,12 +43,116 @@ HttpResponse as _HttpResponse, HttpRequest as _HttpRequest ) +from ..utils._utils import _case_insensitive_dict +from ..utils._pipeline_transport_rest_shared import ( + _pad_attr_name, + BytesIOSocket, + _decode_parts_helper, + _get_raw_parts_helper, + _parts_helper, +) if TYPE_CHECKING: from typing import Any, Optional, Iterator, MutableMapping, Callable +class _HttpResponseBackcompatMixinBase(object): + """Base Backcompat mixin for responses. + + This mixin is used by both sync and async HttpResponse + backcompat mixins. + """ + + def __getattr__(self, attr): + backcompat_attrs = [ + "body", + "internal_response", + "block_size", + "stream_download", + ] + attr = _pad_attr_name(attr, backcompat_attrs) + return self.__getattribute__(attr) + + def __setattr__(self, attr, value): + backcompat_attrs = [ + "block_size", + "internal_response", + "request", + "status_code", + "headers", + "reason", + "content_type", + "stream_download", + ] + attr = _pad_attr_name(attr, backcompat_attrs) + super(_HttpResponseBackcompatMixinBase, self).__setattr__(attr, value) + + def _body(self): + """DEPRECATED: Get the response body. + This is deprecated and will be removed in a later release. + You should get it through the `content` property instead + """ + self.read() + return self.content # pylint: disable=no-member + + def _decode_parts(self, message, http_response_type, requests): + """Helper for _decode_parts. + + Rebuild an HTTP response from pure string. + """ + def _deserialize_response( + http_response_as_bytes, http_request, http_response_type + ): + local_socket = BytesIOSocket(http_response_as_bytes) + response = _HTTPResponse(local_socket, method=http_request.method) + response.begin() + return http_response_type(request=http_request, internal_response=response) + + return _decode_parts_helper( + self, + message, + http_response_type or RestHttpClientTransportResponse, + requests, + _deserialize_response + ) + + def _get_raw_parts(self, http_response_type=None): + """Helper for get_raw_parts + + Assuming this body is multipart, return the iterator or parts. + + If parts are application/http use http_response_type or HttpClientTransportResponse + as enveloppe. + """ + return _get_raw_parts_helper( + self, http_response_type or RestHttpClientTransportResponse + ) + + def _stream_download(self, pipeline, **kwargs): + """DEPRECATED: Generator for streaming request body data. + This is deprecated and will be removed in a later release. + You should use `iter_bytes` or `iter_raw` instead. + :rtype: iterator[bytes] + """ + return self._stream_download_generator(pipeline, self, **kwargs) + +class HttpResponseBackcompatMixin(_HttpResponseBackcompatMixinBase): + """Backcompat mixin for sync HttpResponses""" + + def __getattr__(self, attr): + backcompat_attrs = ["parts"] + attr = _pad_attr_name(attr, backcompat_attrs) + return super(HttpResponseBackcompatMixin, self).__getattr__(attr) + + def parts(self): + """DEPRECATED: Assuming the content-type is multipart/mixed, will return the parts as an async iterator. + This is deprecated and will be removed in a later release. + :rtype: Iterator + :raises ValueError: If the content is not multipart/mixed + """ + return _parts_helper(self) + -class _HttpResponseBaseImpl(_HttpResponseBase): # pylint: disable=too-many-instance-attributes +class _HttpResponseBaseImpl(_HttpResponseBase, _HttpResponseBackcompatMixinBase): # pylint: disable=too-many-instance-attributes """Base Implementation class for azure.core.rest.HttpRespone and azure.core.rest.AsyncHttpResponse Since the rest responses are abstract base classes, we need to implement them for each of our transport @@ -239,7 +344,7 @@ def __repr__(self): self.status_code, self.reason, content_type_str ) -class HttpResponseImpl(_HttpResponseBaseImpl, _HttpResponse): +class HttpResponseImpl(_HttpResponseBaseImpl, _HttpResponse, HttpResponseBackcompatMixin): """HttpResponseImpl built on top of our HttpResponse protocol class. Since ~azure.core.rest.HttpResponse is an abstract base class, we need to @@ -321,3 +426,40 @@ def iter_raw(self): ): yield part self.close() + +class _RestHttpClientTransportResponseBackcompatBaseMixin(_HttpResponseBackcompatMixinBase): + + def body(self): + if self._content is None: + self._content = self.internal_response.read() + return self.content + +class _RestHttpClientTransportResponseBase(_HttpResponseBaseImpl, _RestHttpClientTransportResponseBackcompatBaseMixin): + + def __init__(self, **kwargs): + internal_response = kwargs.pop("internal_response") + headers = _case_insensitive_dict(internal_response.getheaders()) + super(_RestHttpClientTransportResponseBase, self).__init__( + internal_response=internal_response, + status_code=internal_response.status, + reason=internal_response.reason, + headers=headers, + content_type=headers.get("Content-Type"), + stream_download_generator=None, + **kwargs + ) + +class RestHttpClientTransportResponse(_RestHttpClientTransportResponseBase, HttpResponseImpl): + """Create a Rest HTTPResponse from an http.client response. + """ + + def iter_bytes(self): + raise TypeError("We do not support iter_bytes for this transport response") + + def iter_raw(self): + raise TypeError("We do not support iter_raw for this transport response") + + def read(self): + if self._content is None: + self._content = self._internal_response.read() + return self._content diff --git a/sdk/core/azure-core/azure/core/rest/_http_response_impl_async.py b/sdk/core/azure-core/azure/core/rest/_http_response_impl_async.py index 8ea5a1d03d25..6bc93ea0b2e0 100644 --- a/sdk/core/azure-core/azure/core/rest/_http_response_impl_async.py +++ b/sdk/core/azure-core/azure/core/rest/_http_response_impl_async.py @@ -25,9 +25,35 @@ # -------------------------------------------------------------------------- from typing import AsyncIterator from ._rest_py3 import AsyncHttpResponse as _AsyncHttpResponse -from ._http_response_impl import _HttpResponseBaseImpl +from ._http_response_impl import ( + _HttpResponseBaseImpl, _HttpResponseBackcompatMixinBase, _RestHttpClientTransportResponseBase +) +from ..utils._pipeline_transport_rest_shared import _pad_attr_name +from ..utils._pipeline_transport_rest_shared_async import _PartGenerator -class AsyncHttpResponseImpl(_HttpResponseBaseImpl, _AsyncHttpResponse): + +class AsyncHttpResponseBackcompatMixin(_HttpResponseBackcompatMixinBase): + """Backcompat mixin for async responses""" + + def __getattr__(self, attr): + backcompat_attrs = ["parts"] + attr = _pad_attr_name(attr, backcompat_attrs) + return super().__getattr__(attr) + + def parts(self): + """DEPRECATED: Assuming the content-type is multipart/mixed, will return the parts as an async iterator. + This is deprecated and will be removed in a later release. + :rtype: AsyncIterator + :raises ValueError: If the content is not multipart/mixed + """ + if not self.content_type or not self.content_type.startswith("multipart/mixed"): + raise ValueError( + "You can't get parts if the response is not multipart/mixed" + ) + + return _PartGenerator(self, default_http_response_type=RestAsyncHttpClientTransportResponse) + +class AsyncHttpResponseImpl(_HttpResponseBaseImpl, _AsyncHttpResponse, AsyncHttpResponseBackcompatMixin): """AsyncHttpResponseImpl built on top of our HttpResponse protocol class. Since ~azure.core.rest.AsyncHttpResponse is an abstract base class, we need to @@ -114,3 +140,18 @@ def __repr__(self) -> str: return "".format( self.status_code, self.reason, content_type_str ) + +class RestAsyncHttpClientTransportResponse(_RestHttpClientTransportResponseBase, AsyncHttpResponseImpl): + """Create a Rest HTTPResponse from an http.client response. + """ + + async def iter_bytes(self): + raise TypeError("We do not support iter_bytes for this transport response") + + async def iter_raw(self): + raise TypeError("We do not support iter_raw for this transport response") + + async def read(self): + if self._content is None: + self._content = self._internal_response.read() + return self._content diff --git a/sdk/core/azure-core/azure/core/rest/_requests_basic.py b/sdk/core/azure-core/azure/core/rest/_requests_basic.py index 46e898cda44f..d62fc9d552c8 100644 --- a/sdk/core/azure-core/azure/core/rest/_requests_basic.py +++ b/sdk/core/azure-core/azure/core/rest/_requests_basic.py @@ -30,7 +30,7 @@ from requests.structures import CaseInsensitiveDict -from ._http_response_impl import _HttpResponseBaseImpl, HttpResponseImpl +from ._http_response_impl import _HttpResponseBaseImpl, HttpResponseImpl, _HttpResponseBackcompatMixinBase from ..pipeline.transport._requests_basic import StreamDownloadGenerator class _ItemsView(collections.ItemsView): @@ -56,7 +56,22 @@ def items(self): """Return a new view of the dictionary's items.""" return _ItemsView(self) -class _RestRequestsTransportResponseBase(_HttpResponseBaseImpl): +class _RestRequestsTransportResponseBaseMixin(_HttpResponseBackcompatMixinBase): + """Backcompat mixin for the sync and async requests responses + + Overriding the default mixin behavior here because we need to synchronously + read the response's content for the async requests responses + """ + + def _body(self): + # Since requests is not an async library, for backcompat, users should + # be able to access the body directly without loading it first (like we have to do + # in aiohttp). So here, we set self._content to self._internal_response.content, + # which is similar to read, without the async call. + self._content = self._internal_response.content + return self._content + +class _RestRequestsTransportResponseBase(_HttpResponseBaseImpl, _RestRequestsTransportResponseBaseMixin): def __init__(self, **kwargs): internal_response = kwargs.pop("internal_response") content = None diff --git a/sdk/core/azure-core/azure/core/rest/_rest.py b/sdk/core/azure-core/azure/core/rest/_rest.py index 734d28017ae2..437e0823bac6 100644 --- a/sdk/core/azure-core/azure/core/rest/_rest.py +++ b/sdk/core/azure-core/azure/core/rest/_rest.py @@ -179,6 +179,7 @@ def __deepcopy__(self, memo=None): ) request._data = copy.deepcopy(self._data, memo) request._files = copy.deepcopy(self._files, memo) + self._add_backcompat_properties(request, memo) return request except (ValueError, TypeError): return copy.copy(self) diff --git a/sdk/core/azure-core/azure/core/rest/_rest_py3.py b/sdk/core/azure-core/azure/core/rest/_rest_py3.py index 2d96956434c1..b3fa25445f67 100644 --- a/sdk/core/azure-core/azure/core/rest/_rest_py3.py +++ b/sdk/core/azure-core/azure/core/rest/_rest_py3.py @@ -173,6 +173,7 @@ def __deepcopy__(self, memo=None) -> "HttpRequest": ) request._data = copy.deepcopy(self._data, memo) request._files = copy.deepcopy(self._files, memo) + self._add_backcompat_properties(request, memo) return request except (ValueError, TypeError): return copy.copy(self) diff --git a/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py index 8a66e238090e..686bf4e95ce0 100644 --- a/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py +++ b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py @@ -5,10 +5,17 @@ # license information. # -------------------------------------------------------------------------- from __future__ import absolute_import + +from io import BytesIO +from email.message import Message + +try: + from email import message_from_bytes as message_parser +except ImportError: # 2.7 + from email import message_from_string as message_parser # type: ignore import os from typing import TYPE_CHECKING, cast, IO -from email.message import Message from six.moves.http_client import HTTPConnection try: @@ -18,6 +25,13 @@ binary_type = bytes # type: ignore from urllib.parse import urlparse +from ..pipeline import ( + PipelineRequest, + PipelineResponse, + PipelineContext, +) +from ..pipeline._tools import await_result as _await_result + if TYPE_CHECKING: from typing import ( # pylint: disable=ungrouped-imports Dict, @@ -25,6 +39,9 @@ Union, Tuple, Optional, + Callable, + Type, + Iterator, ) # importing both the py3 RestHttpRequest and the fallback RestHttpRequest from azure.core.rest._rest_py3 import HttpRequest as RestHttpRequestPy3 @@ -35,6 +52,26 @@ HTTPRequestType = Union[ RestHttpRequestPy3, RestHttpRequestPy2, PipelineTransportHttpRequest ] + from ..pipeline.policies import SansIOHTTPPolicy + from azure.core.pipeline.transport import ( + HttpResponse as PipelineTransportHttpResponse, + AioHttpTransportResponse as PipelineTransportAioHttpTransportResponse, + ) + from azure.core.pipeline.transport._base import ( + _HttpResponseBase as PipelineTransportHttpResponseBase + ) + +class BytesIOSocket(object): + """Mocking the "makefile" of socket for HTTPResponse. + This can be used to create a http.client.HTTPResponse object + based on bytes and not a real socket. + """ + + def __init__(self, bytes_data): + self.bytes_data = bytes_data + + def makefile(self, *_): + return BytesIO(self.bytes_data) def _format_parameters_helper(http_request, params): """Helper for format_parameters. @@ -182,6 +219,98 @@ def _serialize_request(http_request): ) return serializer.buffer +def _decode_parts_helper( + response, # type: PipelineTransportHttpResponseBase + message, # type: Message + http_response_type, # type: Type[PipelineTransportHttpResponseBase] + requests, # type: List[PipelineTransportHttpRequest] + deserialize_response # type: Callable +): + # type: (...) -> List[PipelineTransportHttpResponse] + """Helper for _decode_parts. + + Rebuild an HTTP response from pure string. + """ + responses = [] + for index, raw_reponse in enumerate(message.get_payload()): + content_type = raw_reponse.get_content_type() + if content_type == "application/http": + responses.append( + deserialize_response( + raw_reponse.get_payload(decode=True), + requests[index], + http_response_type=http_response_type, + ) + ) + elif content_type == "multipart/mixed" and requests[index].multipart_mixed_info: + # The message batch contains one or more change sets + changeset_requests = requests[index].multipart_mixed_info[0] # type: ignore + changeset_responses = response._decode_parts(raw_reponse, http_response_type, changeset_requests) # pylint: disable=protected-access + responses.extend(changeset_responses) + else: + raise ValueError( + "Multipart doesn't support part other than application/http for now" + ) + return responses + +def _get_raw_parts_helper(response, http_response_type): + """Helper for _get_raw_parts + + Assuming this body is multipart, return the iterator or parts. + + If parts are application/http use http_response_type or HttpClientTransportResponse + as enveloppe. + """ + body_as_bytes = response.body() + # In order to use email.message parser, I need full HTTP bytes. Faking something to make the parser happy + http_body = ( + b"Content-Type: " + + response.content_type.encode("ascii") + + b"\r\n\r\n" + + body_as_bytes + ) + message = message_parser(http_body) # type: Message + requests = response.request.multipart_mixed_info[0] + return response._decode_parts(message, http_response_type, requests) # pylint: disable=protected-access + +def _parts_helper(response): + # type: (PipelineTransportHttpResponse) -> Iterator[PipelineTransportHttpResponse] + """Assuming the content-type is multipart/mixed, will return the parts as an iterator. + + :rtype: iterator[HttpResponse] + :raises ValueError: If the content is not multipart/mixed + """ + if not response.content_type or not response.content_type.startswith("multipart/mixed"): + raise ValueError( + "You can't get parts if the response is not multipart/mixed" + ) + + responses = response._get_raw_parts() # pylint: disable=protected-access + if response.request.multipart_mixed_info: + policies = response.request.multipart_mixed_info[1] # type: List[SansIOHTTPPolicy] + + # Apply on_response concurrently to all requests + import concurrent.futures + + def parse_responses(response): + http_request = response.request + context = PipelineContext(None) + pipeline_request = PipelineRequest(http_request, context) + pipeline_response = PipelineResponse( + http_request, response, context=context + ) + + for policy in policies: + _await_result(policy.on_response, pipeline_request, pipeline_response) + + with concurrent.futures.ThreadPoolExecutor() as executor: + # List comprehension to raise exceptions if happened + [ # pylint: disable=expression-not-assigned, unnecessary-comprehension + _ for _ in executor.map(parse_responses, responses) + ] + + return responses + def _format_data_helper(data): # type: (Union[str, IO]) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]] """Helper for _format_data. @@ -190,7 +319,7 @@ def _format_data_helper(data): a string for a form-data request. :param data: The request field data. - :type data: str or file-like object. + :type data: str or file-like object. """ if hasattr(data, "read"): data = cast(IO, data) @@ -202,3 +331,32 @@ def _format_data_helper(data): pass return (data_name, data, "application/octet-stream") return (None, cast(str, data)) + +def _aiohttp_body_helper(response): + # pylint: disable=protected-access + # type: (PipelineTransportAioHttpTransportResponse) -> bytes + """Helper for body method of Aiohttp responses. + + Since aiohttp body methods need decompression work synchronously, + need to share thid code across old and new aiohttp transport responses + for backcompat. + + :rtype: bytes + """ + if response._content is None: + raise ValueError("Body is not available. Call async method load_body, or do your call with stream=False.") + if not response._decompress: + return response._content + enc = response.headers.get('Content-Encoding') + if not enc: + return response._content + enc = enc.lower() + if enc in ("gzip", "deflate"): + if response._decompressed_content: + return response._decompressed_content + import zlib + zlib_mode = 16 + zlib.MAX_WBITS if enc == "gzip" else zlib.MAX_WBITS + decompressor = zlib.decompressobj(wbits=zlib_mode) + response._decompressed_content = decompressor.decompress(response._content) + return response._decompressed_content + return response._content diff --git a/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared_async.py b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared_async.py new file mode 100644 index 000000000000..bdd2fe276917 --- /dev/null +++ b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared_async.py @@ -0,0 +1,63 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import asyncio +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Any +from ..pipeline import PipelineContext, PipelineRequest, PipelineResponse +from ..pipeline._tools_async import await_result as _await_result + +if TYPE_CHECKING: + from typing import List + from ..pipeline.policies import SansIOHTTPPolicy + +class _PartGenerator(AsyncIterator): + """Until parts is a real async iterator, wrap the sync call. + + :param parts: An iterable of parts + """ + + def __init__(self, response, default_http_response_type: Any) -> None: + self._response = response + self._parts = None + self._default_http_response_type = default_http_response_type + + async def _parse_response(self): + responses = self._response._get_raw_parts( # pylint: disable=protected-access + http_response_type=self._default_http_response_type + ) + if self._response.request.multipart_mixed_info: + policies = self._response.request.multipart_mixed_info[ + 1 + ] # type: List[SansIOHTTPPolicy] + + async def parse_responses(response): + http_request = response.request + context = PipelineContext(None) + pipeline_request = PipelineRequest(http_request, context) + pipeline_response = PipelineResponse( + http_request, response, context=context + ) + + for policy in policies: + await _await_result( + policy.on_response, pipeline_request, pipeline_response + ) + + # Not happy to make this code asyncio specific, but that's multipart only for now + # If we need trio and multipart, let's reinvesitgate that later + await asyncio.gather(*[parse_responses(res) for res in responses]) + + return responses + + async def __anext__(self): + if not self._parts: + self._parts = iter(await self._parse_response()) + + try: + return next(self._parts) + except StopIteration: + raise StopAsyncIteration() diff --git a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py index 0c1362f93594..2120b8efb018 100644 --- a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py @@ -12,7 +12,6 @@ from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import AsyncPipeline from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, SansIOHTTPPolicy -from azure.core.pipeline.transport import HttpRequest import pytest pytestmark = pytest.mark.asyncio diff --git a/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py b/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py index 5165c523913d..82a5d21bb4d4 100644 --- a/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py @@ -27,7 +27,8 @@ import json import pickle import re -from utils import HTTP_REQUESTS, is_rest +from utils import HTTP_REQUESTS +from azure.core.pipeline._tools import is_rest import types import unittest try: @@ -50,7 +51,7 @@ from azure.core.polling.async_base_polling import ( AsyncLROBasePolling, ) - +from utils import ASYNCIO_REQUESTS_TRANSPORT_RESPONSES, request_and_responses_product, create_transport_response class SimpleResource: """An implementation of Python 3 SimpleNamespace. @@ -85,8 +86,9 @@ class BadEndpointError(Exception): CLIENT = AsyncPipelineClient("http://example.org") CLIENT.http_request_type = None +CLIENT.http_response_type = None async def mock_run(client_self, request, **kwargs): - return TestBasePolling.mock_update(client_self.http_request_type, request.url) + return TestBasePolling.mock_update(client_self.http_request_type, client_self.http_response_type, request.url) CLIENT._pipeline.run = types.MethodType(mock_run, CLIENT) @@ -164,14 +166,15 @@ def test_base_polling_continuation_token(client, polling_response): @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_post(async_pipeline_client_builder, deserialization_cb, http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +async def test_post(async_pipeline_client_builder, deserialization_cb, http_request, http_response): # Test POST LRO with both Location and Operation-Location # The initial response contains both Location and Operation-Location, a 202 and no Body initial_response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, { @@ -187,6 +190,7 @@ async def send(request, **kwargs): if request.url == 'http://example.org/location': return TestBasePolling.mock_send( http_request, + http_response, 'GET', 200, body={'location_result': True} @@ -194,6 +198,7 @@ async def send(request, **kwargs): elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( http_request, + http_response, 'GET', 200, body={'status': 'Succeeded'} @@ -220,6 +225,7 @@ async def send(request, **kwargs): if request.url == 'http://example.org/location': return TestBasePolling.mock_send( http_request, + http_response, 'GET', 200, body=None @@ -227,6 +233,7 @@ async def send(request, **kwargs): elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( http_request, + http_response, 'GET', 200, body={'status': 'Succeeded'} @@ -246,14 +253,15 @@ async def send(request, **kwargs): @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_post_resource_location(async_pipeline_client_builder, deserialization_cb, http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +async def test_post_resource_location(async_pipeline_client_builder, deserialization_cb, http_request, http_response): # ResourceLocation # The initial response contains both Location and Operation-Location, a 202 and no Body initial_response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, { @@ -268,6 +276,7 @@ async def send(request, **kwargs): if request.url == 'http://example.org/resource_location': return TestBasePolling.mock_send( http_request, + http_response, 'GET', 200, body={'location_result': True} @@ -275,6 +284,7 @@ async def send(request, **kwargs): elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( http_request, + http_response, 'GET', 200, body={'status': 'Succeeded', 'resourceLocation': 'http://example.org/resource_location'} @@ -297,7 +307,7 @@ class TestBasePolling(object): convert = re.compile('([a-z0-9])([A-Z])') @staticmethod - def mock_send(http_request, method, status, headers=None, body=RESPONSE_BODY): + def mock_send(http_request, http_response, method, status, headers=None, body=RESPONSE_BODY): if headers is None: headers = {} response = Response() @@ -331,18 +341,17 @@ def mock_send(http_request, method, status, headers=None, body=RESPONSE_BODY): None, # form_content None # stream_content ) - + response = create_transport_response(http_response, request, response) + if is_rest(http_response): + response.body() return PipelineResponse( request, - AsyncioRequestsTransportResponse( - request, - response, - ), + response, None # context ) @staticmethod - def mock_update(http_request, url, headers=None): + def mock_update(http_request, http_response, url, headers=None): response = Response() response._content_consumed = True response.request = mock.create_autospec(Request) @@ -379,12 +388,13 @@ def mock_update(http_request, url, headers=None): response.request.url, ) + response = create_transport_response(http_response, request, response) + if is_rest(http_response): + response.body() + return PipelineResponse( request, - AsyncioRequestsTransportResponse( - request, - response, - ), + response, None # context ) @@ -419,13 +429,14 @@ def mock_deserialization_no_body(pipeline_response): return None @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_long_running_put(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +async def test_long_running_put(http_request, http_response): #TODO: Test custom header field CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response # Test throw on non LRO related status code response = TestBasePolling.mock_send( - http_request, 'PUT', 1000, {} + http_request, http_response, 'PUT', 1000, {} ) with pytest.raises(HttpResponseError): await async_poller(CLIENT, response, @@ -439,6 +450,7 @@ async def test_long_running_put(http_request): } response = TestBasePolling.mock_send( http_request, + http_response, 'PUT', 201, {}, response_body ) @@ -455,6 +467,7 @@ def no_update_allowed(url, headers=None): # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'PUT', 201, {'operation-location': ASYNC_URL}) polling_method = AsyncLROBasePolling(0) @@ -467,6 +480,7 @@ def no_update_allowed(url, headers=None): # Test polling location header response = TestBasePolling.mock_send( http_request, + http_response, 'PUT', 201, {'location': LOCATION_URL}) polling_method = AsyncLROBasePolling(0) @@ -480,6 +494,7 @@ def no_update_allowed(url, headers=None): response_body = {} # Empty will raise response = TestBasePolling.mock_send( http_request, + http_response, 'PUT', 201, {'location': LOCATION_URL}, response_body) polling_method = AsyncLROBasePolling(0) @@ -492,6 +507,7 @@ def no_update_allowed(url, headers=None): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'PUT', 201, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -502,6 +518,7 @@ def no_update_allowed(url, headers=None): # Test fail to poll from location header response = TestBasePolling.mock_send( http_request, + http_response, 'PUT', 201, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -510,12 +527,14 @@ def no_update_allowed(url, headers=None): AsyncLROBasePolling(0)) @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_long_running_patch(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +async def test_long_running_patch(http_request, http_response): CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response # Test polling from location header response = TestBasePolling.mock_send( http_request, + http_response, 'PATCH', 202, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -529,6 +548,7 @@ async def test_long_running_patch(http_request): # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'PATCH', 202, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -542,6 +562,7 @@ async def test_long_running_patch(http_request): # Test polling from location header response = TestBasePolling.mock_send( http_request, + http_response, 'PATCH', 200, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -555,6 +576,7 @@ async def test_long_running_patch(http_request): # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'PATCH', 200, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -568,6 +590,7 @@ async def test_long_running_patch(http_request): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'PATCH', 202, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -578,6 +601,7 @@ async def test_long_running_patch(http_request): # Test fail to poll from location header response = TestBasePolling.mock_send( http_request, + http_response, 'PATCH', 202, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -586,12 +610,14 @@ async def test_long_running_patch(http_request): AsyncLROBasePolling(0)) @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_long_running_delete(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +async def test_long_running_delete(http_request, http_response): # Test polling from operation-location header CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response response = TestBasePolling.mock_send( http_request, + http_response, 'DELETE', 202, {'operation-location': ASYNC_URL}, body="" @@ -604,12 +630,14 @@ async def test_long_running_delete(http_request): assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_long_running_post(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +async def test_long_running_post(http_request, http_response): CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 201, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -622,6 +650,7 @@ async def test_long_running_post(http_request): # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -634,6 +663,7 @@ async def test_long_running_post(http_request): # Test polling from location header response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -647,6 +677,7 @@ async def test_long_running_post(http_request): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -657,6 +688,7 @@ async def test_long_running_post(http_request): # Test fail to poll from location header response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -665,15 +697,17 @@ async def test_long_running_post(http_request): AsyncLROBasePolling(0)) @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_long_running_negative(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +async def test_long_running_negative(http_request, http_response): global LOCATION_BODY global POLLING_STATUS CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response # Test LRO PUT throws for invalid json LOCATION_BODY = '{' response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}) poll = async_poller( @@ -688,6 +722,7 @@ async def test_long_running_negative(http_request): LOCATION_BODY = '{\'"}' response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}) poll = async_poller(CLIENT, response, @@ -700,6 +735,7 @@ async def test_long_running_negative(http_request): POLLING_STATUS = 203 response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}) poll = async_poller(CLIENT, response, diff --git a/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py b/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py index 5632a2097131..6f6b61744b02 100644 --- a/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py @@ -3,11 +3,12 @@ # Licensed under the MIT License. See LICENSE.txt in the project root for # license information. # ------------------------------------------------------------------------- -from azure.core.pipeline.transport import AsyncHttpResponse, AsyncHttpTransport, AioHttpTransport +from azure.core.pipeline.transport import AsyncHttpResponse as PipelineTransportAsyncHttpResponse, AsyncHttpTransport, AioHttpTransport +from azure.core.rest._http_response_impl_async import AsyncHttpResponseImpl as RestAsyncHttpResponse from azure.core.pipeline.policies import HeadersPolicy from azure.core.pipeline import AsyncPipeline from azure.core.exceptions import HttpResponseError -from utils import HTTP_REQUESTS +from utils import HTTP_REQUESTS, request_and_responses_product import pytest @@ -23,15 +24,39 @@ async def close(self): pass async def send(self, request, **kwargs): pass -class MockResponse(AsyncHttpResponse): +class PipelineTransportMockResponse(PipelineTransportAsyncHttpResponse): def __init__(self, request, body, content_type): - super(MockResponse, self).__init__(request, None) + super().__init__(request, None) self._body = body self.content_type = content_type def body(self): return self._body +class RestMockResponse(RestAsyncHttpResponse): + def __init__(self, request, body, content_type): + super(RestMockResponse, self).__init__( + request=request, + internal_response=None, + content_type=content_type, + block_size=None, + status_code=200, + reason="OK", + headers={}, + stream_download_generator=None, + ) + # the impl takes in a lot more kwargs. It's not public and is a + # helper implementation shared across our azure core transport responses + self._content = body + + def body(self): + return self._content + + @property + def content(self): + return self._content + +MOCK_RESPONSES = [PipelineTransportMockResponse, RestMockResponse] @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) @@ -431,8 +456,8 @@ async def test_multipart_send_with_combination_changeset_middle(http_request): @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_multipart_receive(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +async def test_multipart_receive(http_request, mock_response): class ResponsePolicy(object): def on_response(self, request, response): @@ -481,7 +506,7 @@ async def on_response(self, request, response): "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" ) - response = MockResponse( + response = mock_response( request, body_as_str.encode('ascii'), "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -505,8 +530,8 @@ async def on_response(self, request, response): @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_multipart_receive_with_one_changeset(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +async def test_multipart_receive_with_one_changeset(http_request, mock_response): changeset = http_request("", "") changeset.set_multipart_mixed( http_request("DELETE", "/container0/blob0"), @@ -544,7 +569,7 @@ async def test_multipart_receive_with_one_changeset(http_request): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -560,8 +585,8 @@ async def test_multipart_receive_with_one_changeset(http_request): @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_multipart_receive_with_multiple_changesets(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +async def test_multipart_receive_with_multiple_changesets(http_request, mock_response): changeset1 = http_request("", "") changeset1.set_multipart_mixed( @@ -630,7 +655,7 @@ async def test_multipart_receive_with_multiple_changesets(http_request): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -647,8 +672,8 @@ async def test_multipart_receive_with_multiple_changesets(http_request): @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_multipart_receive_with_combination_changeset_first(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +async def test_multipart_receive_with_combination_changeset_first(http_request, mock_response): changeset = http_request("", "") changeset.set_multipart_mixed( @@ -697,7 +722,7 @@ async def test_multipart_receive_with_combination_changeset_first(http_request): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -711,21 +736,23 @@ async def test_multipart_receive_with_combination_changeset_first(http_request): assert parts[1].status_code == 202 assert parts[2].status_code == 404 -def test_raise_for_status_bad_response(): - response = MockResponse(request=None, body=None, content_type=None) +@pytest.mark.parametrize("mock_response", MOCK_RESPONSES) +def test_raise_for_status_bad_response(mock_response): + response = mock_response(request=None, body=None, content_type=None) response.status_code = 400 with pytest.raises(HttpResponseError): response.raise_for_status() -def test_raise_for_status_good_response(): - response = MockResponse(request=None, body=None, content_type=None) +@pytest.mark.parametrize("mock_response", MOCK_RESPONSES) +def test_raise_for_status_good_response(mock_response): + response = mock_response(request=None, body=None, content_type=None) response.status_code = 200 response.raise_for_status() @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_multipart_receive_with_combination_changeset_middle(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +async def test_multipart_receive_with_combination_changeset_middle(http_request, mock_response): changeset = http_request("", "") changeset.set_multipart_mixed(http_request("DELETE", "/container1/blob1")) @@ -775,7 +802,7 @@ async def test_multipart_receive_with_combination_changeset_middle(http_request) b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -791,8 +818,8 @@ async def test_multipart_receive_with_combination_changeset_middle(http_request) @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_multipart_receive_with_combination_changeset_last(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +async def test_multipart_receive_with_combination_changeset_last(http_request, mock_response): changeset = http_request("", "") changeset.set_multipart_mixed( @@ -842,7 +869,7 @@ async def test_multipart_receive_with_combination_changeset_last(http_request): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -858,8 +885,8 @@ async def test_multipart_receive_with_combination_changeset_last(http_request): @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_multipart_receive_with_bom(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +async def test_multipart_receive_with_bom(http_request, mock_response): req0 = http_request("DELETE", "/container0/blob0") @@ -881,7 +908,7 @@ async def test_multipart_receive_with_bom(http_request): b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -898,8 +925,8 @@ async def test_multipart_receive_with_bom(http_request): @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_recursive_multipart_receive(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +async def test_recursive_multipart_receive(http_request, mock_response): req0 = http_request("DELETE", "/container0/blob0") internal_req0 = http_request("DELETE", "/container0/blob0") req0.set_multipart_mixed(internal_req0) @@ -930,7 +957,7 @@ async def test_recursive_multipart_receive(http_request): "--batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6--" ).format(internal_body_as_str) - response = MockResponse( + response = mock_response( request, body_as_str.encode('ascii'), "multipart/mixed; boundary=batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6" diff --git a/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py b/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py index ab8d4ca72b37..8d2b52360e50 100644 --- a/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py @@ -14,16 +14,13 @@ PipelineRequest, PipelineContext ) -from azure.core.pipeline.transport import ( - HttpResponse, -) from azure.core.pipeline.policies import ( HttpLoggingPolicy, ) -from utils import HTTP_REQUESTS +from utils import HTTP_RESPONSES, request_and_responses_product, create_http_response -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_http_logger(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_http_logger(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -42,7 +39,7 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) universal_request = http_request('GET', 'http://localhost/') - http_response = HttpResponse(universal_request, None) + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) @@ -144,8 +141,8 @@ def emit(self, record): mock_handler.reset() -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_http_logger_operation_level(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_http_logger_operation_level(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -165,7 +162,7 @@ def emit(self, record): kwargs={'logger': logger} universal_request = http_request('GET', 'http://localhost/') - http_response = HttpResponse(universal_request, None) + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None, **kwargs)) @@ -221,8 +218,8 @@ def emit(self, record): mock_handler.reset() -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_http_logger_with_body(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_http_logger_with_body(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -242,7 +239,7 @@ def emit(self, record): universal_request = http_request('GET', 'http://localhost/') universal_request.body = "testbody" - http_response = HttpResponse(universal_request, None) + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) @@ -264,9 +261,9 @@ def emit(self, record): mock_handler.reset() -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) @pytest.mark.skipif(sys.version_info < (3, 6), reason="types.AsyncGeneratorType does not exist in 3.5") -def test_http_logger_with_generator_body(http_request): +def test_http_logger_with_generator_body(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -288,7 +285,7 @@ def emit(self, record): mock = Mock() mock.__class__ = types.AsyncGeneratorType universal_request.body = mock - http_response = HttpResponse(universal_request, None) + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) diff --git a/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py b/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py index 8a8199e65429..189a09eaccb4 100644 --- a/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py +++ b/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py @@ -7,6 +7,7 @@ from azure.core.pipeline.transport import AsyncioRequestsTransport from utils import HTTP_REQUESTS +from azure.core.pipeline._tools import is_rest import pytest @@ -29,6 +30,8 @@ async def __anext__(self): async with AsyncioRequestsTransport() as transport: req = http_request('GET', 'http://localhost:{}/basic/anything'.format(port), data=AsyncGen()) response = await transport.send(req) + if is_rest(http_request): + assert is_rest(response) assert json.loads(response.text())['data'] == "azerty" @pytest.mark.asyncio @@ -37,4 +40,6 @@ async def test_send_data(port, http_request): async with AsyncioRequestsTransport() as transport: req = http_request('PUT', 'http://localhost:{}/basic/anything'.format(port), data=b"azerty") response = await transport.send(req) + if is_rest(http_request): + assert is_rest(response) assert json.loads(response.text())['data'] == "azerty" diff --git a/sdk/core/azure-core/tests/async_tests/test_request_trio.py b/sdk/core/azure-core/tests/async_tests/test_request_trio.py index 7fafa0d41c28..092a550cb67b 100644 --- a/sdk/core/azure-core/tests/async_tests/test_request_trio.py +++ b/sdk/core/azure-core/tests/async_tests/test_request_trio.py @@ -7,6 +7,7 @@ from azure.core.pipeline.transport import TrioRequestsTransport from utils import HTTP_REQUESTS +from azure.core.pipeline._tools import is_rest import pytest @@ -30,6 +31,8 @@ async def __anext__(self): async with TrioRequestsTransport() as transport: req = http_request('GET', 'http://localhost:{}/basic/anything'.format(port), data=AsyncGen()) response = await transport.send(req) + if is_rest(http_request): + assert is_rest(response) assert json.loads(response.text())['data'] == "azerty" @pytest.mark.trio @@ -38,5 +41,6 @@ async def test_send_data(port, http_request): async with TrioRequestsTransport() as transport: req = http_request('PUT', 'http://localhost:{}/basic/anything'.format(port), data=b"azerty") response = await transport.send(req) - + if is_rest(http_request): + assert is_rest(response) assert json.loads(response.text())['data'] == "azerty" \ No newline at end of file diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_asyncio_transport.py b/sdk/core/azure-core/tests/async_tests/test_rest_asyncio_transport.py index 11a2ca3f52ca..98941fbcd8d5 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_asyncio_transport.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_asyncio_transport.py @@ -50,4 +50,5 @@ async def test_readonly(port): response.raise_for_status() assert isinstance(response, RestAsyncioRequestsTransportResponse) - readonly_checks(response) + from azure.core.pipeline.transport import AsyncioRequestsTransportResponse + readonly_checks(response, old_response_class=AsyncioRequestsTransportResponse) diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_http_response_async.py b/sdk/core/azure-core/tests/async_tests/test_rest_http_response_async.py index 7239dcd7a547..ebbd4079e6ea 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_http_response_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_http_response_async.py @@ -8,6 +8,7 @@ # Thank you httpx for your wonderful tests! import io import pytest +import zlib from azure.core.rest import HttpRequest, AsyncHttpResponse from azure.core.rest._aiohttp import RestAioHttpTransportResponse from azure.core.exceptions import HttpResponseError @@ -171,18 +172,6 @@ async def test_response_no_charset_with_iso_8859_1_content(send_request): assert response.text() == "Accented: �sterreich" assert response.encoding is None -# NOTE: aiohttp isn't liking this -# @pytest.mark.asyncio -# async def test_response_set_explicit_encoding(send_request): -# response = await send_request( -# request=HttpRequest("GET", "/encoding/latin-1-with-utf-8"), -# ) -# assert response.headers["Content-Type"] == "text/plain; charset=utf-8" -# response.encoding = "latin-1" -# await response.read() -# assert response.text() == "Latin 1: ÿ" -# assert response.encoding == "latin-1" - @pytest.mark.asyncio async def test_json(send_request): response = await send_request( @@ -309,4 +298,5 @@ async def test_readonly(send_request): response = await send_request(HttpRequest("GET", "/health")) assert isinstance(response, RestAioHttpTransportResponse) - readonly_checks(response) + from azure.core.pipeline.transport import AioHttpTransportResponse + readonly_checks(response, old_response_class=AioHttpTransportResponse) diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_response_backcompat_async.py b/sdk/core/azure-core/tests/async_tests/test_rest_response_backcompat_async.py new file mode 100644 index 000000000000..80b5a8d2911a --- /dev/null +++ b/sdk/core/azure-core/tests/async_tests/test_rest_response_backcompat_async.py @@ -0,0 +1,321 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import pytest +from azure.core.pipeline.transport import HttpRequest as PipelineTransportHttpRequest +from azure.core.rest import HttpRequest as RestHttpRequest +from azure.core.pipeline import Pipeline +from azure.core.pipeline.transport import AioHttpTransport, AsyncioRequestsTransport, TrioRequestsTransport +from rest_client_async import AsyncTestRestClient + +TRANSPORTS = [AioHttpTransport, AsyncioRequestsTransport] + +@pytest.fixture + +def old_request(port): + return PipelineTransportHttpRequest("GET", "http://localhost:{}/streams/basic".format(port)) + +@pytest.fixture +@pytest.mark.asyncio +async def get_old_response(old_request): + async def _callback(transport, **kwargs): + async with transport() as sender: + return await sender.send(old_request, **kwargs) + return _callback + +@pytest.fixture +@pytest.mark.trio +async def get_old_response_trio(old_request): + async def _callback(**kwargs): + async with TrioRequestsTransport() as sender: + return await sender.send(old_request, **kwargs) + return _callback + +@pytest.fixture +def new_request(port): + return RestHttpRequest("GET", "http://localhost:{}/streams/basic".format(port)) + +@pytest.fixture +@pytest.mark.asyncio +async def get_new_response(new_request): + async def _callback(transport, **kwargs): + async with transport() as sender: + return await sender.send(new_request, **kwargs) + return _callback + +@pytest.fixture +@pytest.mark.trio +async def get_new_response_trio(new_request): + async def _callback(**kwargs): + async with TrioRequestsTransport() as sender: + return await sender.send(new_request, **kwargs) + return _callback + +def _test_response_attr_parity(old_response, new_response): + for attr in dir(old_response): + if not attr[0] == "_": + # if not a private attr, we want partiy + assert hasattr(new_response, attr) + +@pytest.mark.asyncio +@pytest.mark.parametrize("transport", TRANSPORTS) +async def test_response_attr_parity(get_old_response, get_new_response, transport): + old_response = await get_old_response(transport) + new_response = await get_new_response(transport) + _test_response_attr_parity(old_response, new_response) + +@pytest.mark.trio +async def test_response_attr_parity_trio(get_old_response_trio, get_new_response_trio): + old_response = await get_old_response_trio() + new_response = await get_new_response_trio() + _test_response_attr_parity(old_response, new_response) + +def _test_response_set_attrs(old_response, new_response): + for attr in dir(old_response): + if attr[0] == "_": + continue + try: + # if we can set it on the old request, we want to + # be able to set it on the new + setattr(old_response, attr, "foo") + except: + pass + else: + setattr(new_response, attr, "foo") + assert getattr(old_response, attr) == getattr(new_response, attr) == "foo" + +@pytest.mark.asyncio +@pytest.mark.parametrize("transport", TRANSPORTS) +async def test_response_set_attrs(get_old_response, get_new_response, transport): + old_response = await get_old_response(transport) + new_response = await get_new_response(transport) + _test_response_set_attrs(old_response, new_response) + +@pytest.mark.trio +async def test_response_set_attrs_trio(get_old_response_trio, get_new_response_trio): + old_response = await get_old_response_trio() + new_response = await get_new_response_trio() + _test_response_set_attrs(old_response, new_response) + +def _test_response_block_size(old_response, new_response): + assert old_response.block_size == new_response.block_size == 4096 + old_response.block_size = 500 + new_response.block_size = 500 + assert old_response.block_size == new_response.block_size == 500 + +@pytest.mark.asyncio +@pytest.mark.parametrize("transport", TRANSPORTS) +async def test_response_block_size(get_old_response, get_new_response, transport): + old_response = await get_old_response(transport) + new_response = await get_new_response(transport) + _test_response_block_size(old_response, new_response) + +@pytest.mark.trio +async def test_response_block_size_trio(get_old_response_trio, get_new_response_trio): + old_response = await get_old_response_trio() + new_response = await get_new_response_trio() + _test_response_block_size(old_response, new_response) + +@pytest.mark.asyncio +@pytest.mark.parametrize("transport", TRANSPORTS) +async def test_response_body(get_old_response, get_new_response, transport): + old_response = await get_old_response(transport) + new_response = await get_new_response(transport) + assert old_response.body() == new_response.body() == b"Hello, world!" + +@pytest.mark.trio +async def test_response_body_trio(get_old_response_trio, get_new_response_trio): + old_response = await get_old_response_trio() + new_response = await get_new_response_trio() + assert old_response.body() == new_response.body() == b"Hello, world!" + +def _test_response_internal_response(old_response, new_response, port): + assert str(old_response.internal_response.url) == str(new_response.internal_response.url) == "http://localhost:{}/streams/basic".format(port) + old_response.internal_response = "foo" + new_response.internal_response = "foo" + assert old_response.internal_response == new_response.internal_response == "foo" + +@pytest.mark.asyncio +@pytest.mark.parametrize("transport", TRANSPORTS) +async def test_response_internal_response(get_old_response, get_new_response, transport, port): + old_response = await get_old_response(transport) + new_response = await get_new_response(transport) + _test_response_internal_response(old_response, new_response, port) + +@pytest.mark.trio +async def test_response_internal_response_trio(get_old_response_trio, get_new_response_trio, port): + old_response = await get_old_response_trio() + new_response = await get_new_response_trio() + _test_response_internal_response(old_response, new_response, port) + +@pytest.mark.asyncio +@pytest.mark.parametrize("transport", TRANSPORTS) +async def test_response_stream_download(get_old_response, get_new_response, transport): + old_response = await get_old_response(transport, stream=True) + new_response = await get_new_response(transport, stream=True) + pipeline = Pipeline(transport()) + old_string = b"".join([part async for part in old_response.stream_download(pipeline=pipeline)]) + new_string = b"".join([part async for part in new_response.stream_download(pipeline=pipeline)]) + + # aiohttp can be flaky for both old and new responses, so since we're just checking backcompat here + # using in instead of equals + assert old_string in b"Hello, world!" + assert new_string in b"Hello, world!" + +@pytest.mark.trio +async def test_response_stream_download_trio(get_old_response_trio, get_new_response_trio): + old_response = await get_old_response_trio(stream=True) + new_response = await get_new_response_trio(stream=True) + pipeline = Pipeline(TrioRequestsTransport()) + old_string = b"".join([part async for part in old_response.stream_download(pipeline=pipeline)]) + new_string = b"".join([part async for part in new_response.stream_download(pipeline=pipeline)]) + assert old_string == new_string == b"Hello, world!" + +def _test_response_request(old_response, new_response, port): + assert old_response.request.url == new_response.request.url == "http://localhost:{}/streams/basic".format(port) + old_response.request = "foo" + new_response.request = "foo" + assert old_response.request == new_response.request == "foo" + +@pytest.mark.asyncio +@pytest.mark.parametrize("transport", TRANSPORTS) +async def test_response_request(get_old_response, get_new_response, port, transport): + old_response = await get_old_response(transport) + new_response = await get_new_response(transport) + _test_response_request(old_response, new_response, port) + +@pytest.mark.trio +async def test_response_request_trio(get_old_response_trio, get_new_response_trio, port): + old_response = await get_old_response_trio() + new_response = await get_new_response_trio() + _test_response_request(old_response, new_response, port) + +def _test_response_status_code(old_response, new_response): + assert old_response.status_code == new_response.status_code == 200 + old_response.status_code = 202 + new_response.status_code = 202 + assert old_response.status_code == new_response.status_code == 202 + +@pytest.mark.asyncio +@pytest.mark.parametrize("transport", TRANSPORTS) +async def test_response_status_code(get_old_response, get_new_response, transport): + old_response = await get_old_response(transport) + new_response = await get_new_response(transport) + _test_response_status_code(old_response, new_response) + +@pytest.mark.trio +async def test_response_status_code_trio(get_old_response_trio, get_new_response_trio): + old_response = await get_old_response_trio() + new_response = await get_new_response_trio() + _test_response_status_code(old_response, new_response) + +def _test_response_headers(old_response, new_response): + assert set(old_response.headers.keys()) == set(new_response.headers.keys()) == set(["Content-Type", "Connection", "Server", "Date"]) + old_response.headers = {"Hello": "world!"} + new_response.headers = {"Hello": "world!"} + assert old_response.headers == new_response.headers == {"Hello": "world!"} + +@pytest.mark.asyncio +@pytest.mark.parametrize("transport", TRANSPORTS) +async def test_response_headers(get_old_response, get_new_response, transport): + old_response = await get_old_response(transport) + new_response = await get_new_response(transport) + _test_response_headers(old_response, new_response) + +@pytest.mark.trio +async def test_response_headers_trio(get_old_response_trio, get_new_response_trio): + old_response = await get_old_response_trio() + new_response = await get_new_response_trio() + _test_response_headers(old_response, new_response) + +def _test_response_reason(old_response, new_response): + assert old_response.reason == new_response.reason == "OK" + old_response.reason = "Not OK" + new_response.reason = "Not OK" + assert old_response.reason == new_response.reason == "Not OK" + +@pytest.mark.asyncio +@pytest.mark.parametrize("transport", TRANSPORTS) +async def test_response_reason(get_old_response, get_new_response, transport): + old_response = await get_old_response(transport) + new_response = await get_new_response(transport) + _test_response_reason(old_response, new_response) + +@pytest.mark.trio +async def test_response_reason_trio(get_old_response_trio, get_new_response_trio): + old_response = await get_old_response_trio() + new_response = await get_new_response_trio() + _test_response_reason(old_response, new_response) + +def _test_response_content_type(old_response, new_response): + assert old_response.content_type == new_response.content_type == "text/html; charset=utf-8" + old_response.content_type = "application/json" + new_response.content_type = "application/json" + assert old_response.content_type == new_response.content_type == "application/json" + +@pytest.mark.asyncio +@pytest.mark.parametrize("transport", TRANSPORTS) +async def test_response_content_type(get_old_response, get_new_response, transport): + old_response = await get_old_response(transport) + new_response = await get_new_response(transport) + _test_response_content_type(old_response, new_response) + +@pytest.mark.trio +async def test_response_content_type_trio(get_old_response_trio, get_new_response_trio): + old_response = await get_old_response_trio() + new_response = await get_new_response_trio() + _test_response_content_type(old_response, new_response) + +def _create_multiapart_request(http_request_class): + class ResponsePolicy(object): + def on_request(self, *args): + return + + def on_response(self, request, response): + response.http_response.headers['x-ms-fun'] = 'true' + + class AsyncResponsePolicy(object): + def on_request(self, *args): + return + + async def on_response(self, request, response): + response.http_response.headers['x-ms-async-fun'] = 'true' + + req0 = http_request_class("DELETE", "/container0/blob0") + req1 = http_request_class("DELETE", "/container1/blob1") + request = http_request_class("POST", "/multipart/request") + request.set_multipart_mixed(req0, req1, policies=[ResponsePolicy(), AsyncResponsePolicy()]) + return request + +async def _test_parts(response): + # hack the content type + parts = [p async for p in response.parts()] + assert len(parts) == 2 + + parts0 = parts[0] + assert parts0.status_code == 202 + assert parts0.headers['x-ms-fun'] == 'true' + assert parts0.headers['x-ms-async-fun'] == 'true' + + parts1 = parts[1] + assert parts1.status_code == 404 + assert parts1.headers['x-ms-fun'] == 'true' + assert parts1.headers['x-ms-async-fun'] == 'true' + +@pytest.mark.asyncio +@pytest.mark.parametrize("transport", TRANSPORTS) +async def test_response_parts(port, transport): + # there's no support for trio + multipart rn + old_request = _create_multiapart_request(PipelineTransportHttpRequest) + new_request = _create_multiapart_request(RestHttpRequest) + old_response = await AsyncTestRestClient(port, transport=transport()).send_request(old_request, stream=True) + new_response = await AsyncTestRestClient(port, transport=transport()).send_request(new_request, stream=True) + if hasattr(old_response, "load_body"): + # only aiohttp has this attr + await old_response.load_body() + await new_response.load_body() + await _test_parts(old_response) + await _test_parts(new_response) diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_trio_transport.py b/sdk/core/azure-core/tests/async_tests/test_rest_trio_transport.py index 18eb6ad7a84b..8213a37d09ea 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_trio_transport.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_trio_transport.py @@ -51,4 +51,5 @@ async def test_readonly(port): response.raise_for_status() assert isinstance(response, RestTrioRequestsTransportResponse) - readonly_checks(response) + from azure.core.pipeline.transport import TrioRequestsTransportResponse + readonly_checks(response, old_response_class=TrioRequestsTransportResponse) diff --git a/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py b/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py index f213b92dbbf7..f9be72d1a9da 100644 --- a/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py @@ -4,7 +4,6 @@ # ------------------------------------ import requests from azure.core.pipeline.transport import ( - AsyncHttpResponse, AsyncHttpTransport, AsyncioRequestsTransportResponse, AioHttpTransport, @@ -13,11 +12,11 @@ from azure.core.pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator from unittest import mock import pytest -from utils import HTTP_REQUESTS +from utils import request_and_responses_product, ASYNC_HTTP_RESPONSES, create_http_response @pytest.mark.asyncio -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -async def test_connection_error_response(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNC_HTTP_RESPONSES)) +async def test_connection_error_response(http_request, http_response): class MockSession(object): def __init__(self): self.auto_decompress = True @@ -40,7 +39,7 @@ async def open(self): async def send(self, request, **kwargs): request = http_request('GET', 'http://localhost/') - response = AsyncHttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 200 return response @@ -68,7 +67,7 @@ async def __call__(self, *args, **kwargs): http_request = http_request('GET', 'http://localhost/') pipeline = AsyncPipeline(MockTransport()) - http_response = AsyncHttpResponse(http_request, None) + http_response = create_http_response(http_response, http_request, None) http_response.internal_response = MockInternalResponse() stream = AioHttpStreamDownloadGenerator(pipeline, http_response, decompress=False) with mock.patch('asyncio.sleep', new_callable=AsyncMock): @@ -76,7 +75,8 @@ async def __call__(self, *args, **kwargs): await stream.__anext__() @pytest.mark.asyncio -async def test_response_streaming_error_behavior(): +@pytest.mark.parametrize("http_response", ASYNC_HTTP_RESPONSES) +async def test_response_streaming_error_behavior(http_response): # Test to reproduce https://github.com/Azure/azure-sdk-for-python/issues/16723 block_size = 103 total_response_size = 500 diff --git a/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py b/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py index b9229c209fa2..344ebd7dfad0 100644 --- a/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py @@ -33,7 +33,8 @@ import trio import pytest -from utils import HTTP_REQUESTS, create_http_request +from utils import HTTP_REQUESTS, AIOHTTP_TRANSPORT_RESPONSES, create_transport_response +from azure.core.pipeline._tools import is_rest @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) @@ -92,7 +93,7 @@ async def do(): assert isinstance(response.status_code, int) -def _create_aiohttp_response(body_bytes, headers=None): +def _create_aiohttp_response(http_response, body_bytes, headers=None): class MockAiohttpClientResponse(aiohttp.ClientResponse): def __init__(self, body_bytes, headers=None): self._body = body_bytes @@ -103,29 +104,36 @@ def __init__(self, body_bytes, headers=None): req_response = MockAiohttpClientResponse(body_bytes, headers) - response = AioHttpTransportResponse( + response = create_transport_response( + http_response, None, # Don't need a request here req_response ) - response._body = body_bytes + response._content = body_bytes return response @pytest.mark.asyncio -async def test_aiohttp_response_text(): +@pytest.mark.parametrize("http_response", AIOHTTP_TRANSPORT_RESPONSES) +async def test_aiohttp_response_text(http_response): for encoding in ["utf-8", "utf-8-sig", None]: res = _create_aiohttp_response( + http_response, b'\xef\xbb\xbf56', {'Content-Type': 'text/plain'} ) + if is_rest(http_response): + await res.read() assert res.text(encoding) == '56', "Encoding {} didn't work".format(encoding) @pytest.mark.asyncio -async def test_aiohttp_response_decompression(): +@pytest.mark.parametrize("http_response", AIOHTTP_TRANSPORT_RESPONSES) +async def test_aiohttp_response_decompression(http_response): res = _create_aiohttp_response( + http_response, b"\x1f\x8b\x08\x00\x00\x00\x00\x00\x04\x00\x8d\x8d\xb1n\xc30\x0cD" b"\xff\x85s\x14HVlY\xda\x8av.\n4\x1d\x9a\x8d\xa1\xe5D\x80m\x01\x12=" b"\x14A\xfe\xbd\x92\x81d\xceB\x1c\xef\xf8\x8e7\x08\x038\xf0\xa67Fj+" @@ -146,9 +154,11 @@ async def test_aiohttp_response_decompression(): assert res.body() == expect, "Decompression didn't work" @pytest.mark.asyncio -async def test_aiohttp_response_decompression_negtive(): +@pytest.mark.parametrize("http_response", AIOHTTP_TRANSPORT_RESPONSES) +async def test_aiohttp_response_decompression_negtive(http_response): import zlib res = _create_aiohttp_response( + http_response, b"\xff\x85s\x14HVlY\xda\x8av.\n4\x1d\x9a\x8d\xa1\xe5D\x80m\x01\x12=" b"\x14A\xfe\xbd\x92\x81d\xceB\x1c\xef\xf8\x8e7\x08\x038\xf0\xa67Fj+" b"\x946\x9d8\x0c4\x08{\x96(\x94mzkh\x1cM/a\x07\x94<\xb2\x1f>\xca8\x86" @@ -162,11 +172,14 @@ async def test_aiohttp_response_decompression_negtive(): with pytest.raises(zlib.error): body = res.body() -def test_repr(): +@pytest.mark.parametrize("http_response", AIOHTTP_TRANSPORT_RESPONSES) +def test_repr(http_response): res = _create_aiohttp_response( + http_response, b'\xef\xbb\xbf56', {} ) res.content_type = "text/plain" - assert repr(res) == "" + class_name = "AsyncHttpResponse" if is_rest(http_response) else "AioHttpTransportResponse" + assert repr(res) == f"<{class_name}: 200 OK, Content-Type: text/plain>" diff --git a/sdk/core/azure-core/tests/test_base_polling.py b/sdk/core/azure-core/tests/test_base_polling.py index c75b401df01d..ad030099c6eb 100644 --- a/sdk/core/azure-core/tests/test_base_polling.py +++ b/sdk/core/azure-core/tests/test_base_polling.py @@ -44,11 +44,12 @@ from azure.core.exceptions import DecodeError, HttpResponseError from azure.core import PipelineClient from azure.core.pipeline import PipelineResponse, Pipeline, PipelineContext -from azure.core.pipeline.transport import RequestsTransportResponse, HttpTransport +from azure.core.pipeline.transport import HttpTransport from azure.core.polling.base_polling import LROBasePolling from azure.core.pipeline.policies._utils import _FixedOffset -from utils import HTTP_REQUESTS, is_rest +from utils import request_and_responses_product, REQUESTS_TRANSPORT_RESPONSES, create_transport_response +from azure.core.pipeline._tools import is_rest class SimpleResource: """An implementation of Python 3 SimpleNamespace. @@ -83,8 +84,9 @@ class BadEndpointError(Exception): CLIENT = PipelineClient("http://example.org") CLIENT.http_request_type = None +CLIENT.http_response_type = None def mock_run(client_self, request, **kwargs): - return TestBasePolling.mock_update(client_self.http_request_type, request.url, request.headers) + return TestBasePolling.mock_update(client_self.http_request_type, client_self.http_response_type, request.url, request.headers) CLIENT._pipeline.run = types.MethodType(mock_run, CLIENT) @@ -127,27 +129,30 @@ def cb(pipeline_response): @pytest.fixture def polling_response(): - polling = LROBasePolling() - headers = {} + def _callback(http_response, headers={}): + polling = LROBasePolling() - response = Response() - response.headers = headers - response.status_code = 200 + response = Response() + response.headers = headers + response.status_code = 200 - polling._pipeline_response = PipelineResponse( - None, - RequestsTransportResponse( + response = create_transport_response( + http_response, None, response, - ), - PipelineContext(None) - ) - polling._initial_response = polling._pipeline_response - return polling, headers - + ) + polling._pipeline_response = PipelineResponse( + None, + response, + PipelineContext(None) + ) + polling._initial_response = polling._pipeline_response + return polling + return _callback -def test_base_polling_continuation_token(client, polling_response): - polling, _ = polling_response +@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) +def test_base_polling_continuation_token(client, polling_response, http_response): + polling = polling_response(http_response) continuation_token = polling.get_continuation_token() assert isinstance(continuation_token, six.string_types) @@ -160,20 +165,17 @@ def test_base_polling_continuation_token(client, polling_response): new_polling = LROBasePolling() new_polling.initialize(*polling_args) - -def test_delay_extraction_int(polling_response): - polling, headers = polling_response - - headers['Retry-After'] = "10" +@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) +def test_delay_extraction_int(polling_response, http_response): + polling = polling_response(http_response, {"Retry-After": "10"}) assert polling._extract_delay() == 10 @pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason="https://stackoverflow.com/questions/11146725/isinstance-and-mocking") -def test_delay_extraction_httpdate(polling_response): - polling, headers = polling_response +@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) +def test_delay_extraction_httpdate(polling_response, http_response): + polling = polling_response(http_response, {"Retry-After": "Mon, 20 Nov 1995 19:12:08 -0500"}) - # Test that I need to retry exactly one hour after, by mocking "now" - headers['Retry-After'] = "Mon, 20 Nov 1995 19:12:08 -0500" from datetime import datetime as basedatetime now_mock_datetime = datetime.datetime(1995, 11, 20, 18, 12, 8, tzinfo=_FixedOffset(-5*60)) @@ -184,14 +186,15 @@ def test_delay_extraction_httpdate(polling_response): assert polling._extract_delay() == 60*60 # one hour in seconds assert str(mock_datetime.now.call_args[0][0]) == "" -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_post(pipeline_client_builder, deserialization_cb, http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) +def test_post(pipeline_client_builder, deserialization_cb, http_request, http_response): # Test POST LRO with both Location and Operation-Location # The initial response contains both Location and Operation-Location, a 202 and no Body initial_response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, { @@ -207,6 +210,7 @@ def send(request, **kwargs): if request.url == 'http://example.org/location': return TestBasePolling.mock_send( http_request, + http_response, 'GET', 200, body={'location_result': True} @@ -214,6 +218,7 @@ def send(request, **kwargs): elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( http_request, + http_response, 'GET', 200, body={'status': 'Succeeded'} @@ -238,15 +243,18 @@ def send(request, **kwargs): assert request.method == 'GET' if request.url == 'http://example.org/location': - return TestBasePolling.mock_send( + response = TestBasePolling.mock_send( http_request, + http_response, 'GET', 200, body=None ).http_response + return response elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( http_request, + http_response, 'GET', 200, body={'status': 'Succeeded'} @@ -264,14 +272,15 @@ def send(request, **kwargs): result = poll.result() assert result is None -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_post_resource_location(pipeline_client_builder, deserialization_cb, http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) +def test_post_resource_location(pipeline_client_builder, deserialization_cb, http_request, http_response): # ResourceLocation # The initial response contains both Location and Operation-Location, a 202 and no Body initial_response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, { @@ -286,6 +295,7 @@ def send(request, **kwargs): if request.url == 'http://example.org/resource_location': return TestBasePolling.mock_send( http_request, + http_response, 'GET', 200, body={'location_result': True} @@ -293,6 +303,7 @@ def send(request, **kwargs): elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( http_request, + http_response, 'GET', 200, body={'status': 'Succeeded', 'resourceLocation': 'http://example.org/resource_location'} @@ -316,7 +327,7 @@ class TestBasePolling(object): convert = re.compile('([a-z0-9])([A-Z])') @staticmethod - def mock_send(http_request, method, status, headers=None, body=RESPONSE_BODY): + def mock_send(http_request, http_response, method, status, headers=None, body=RESPONSE_BODY): if headers is None: headers = {} response = Response() @@ -350,18 +361,19 @@ def mock_send(http_request, method, status, headers=None, body=RESPONSE_BODY): None, # form_content None # stream_content ) - + response = create_transport_response( + http_response, + request, + response, + ) return PipelineResponse( request, - RequestsTransportResponse( - request, - response, - ), + response, None # context ) @staticmethod - def mock_update(http_request, url, headers=None): + def mock_update(http_request, http_response, url, headers=None): response = Response() response._content_consumed = True response.request = mock.create_autospec(Request) @@ -396,13 +408,14 @@ def mock_update(http_request, url, headers=None): response.request.method, response.request.url, ) - + response = create_transport_response( + http_response, + request, + response, + ) return PipelineResponse( request, - RequestsTransportResponse( - request, - response, - ), + response, None # context ) @@ -436,14 +449,16 @@ def mock_deserialization_no_body(pipeline_response): """ return None - @pytest.mark.parametrize("http_request", HTTP_REQUESTS) - def test_long_running_put(self, http_request): + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) + def test_long_running_put(self, http_request, http_response): #TODO: Test custom header field # Test throw on non LRO related status code response = TestBasePolling.mock_send( - http_request, 'PUT', 1000, {}) + http_request, + http_response, 'PUT', 1000, {}) CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response with pytest.raises(HttpResponseError): LROPoller(CLIENT, response, TestBasePolling.mock_outputs, @@ -456,6 +471,7 @@ def test_long_running_put(self, http_request): } response = TestBasePolling.mock_send( http_request, + http_response, 'PUT', 201, {}, response_body ) @@ -471,6 +487,7 @@ def no_update_allowed(url, headers=None): # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'PUT', 201, {'operation-location': ASYNC_URL}) poll = LROPoller(CLIENT, response, @@ -482,6 +499,7 @@ def no_update_allowed(url, headers=None): # Test polling location header response = TestBasePolling.mock_send( http_request, + http_response, 'PUT', 201, {'location': LOCATION_URL}) poll = LROPoller(CLIENT, response, @@ -494,6 +512,7 @@ def no_update_allowed(url, headers=None): response_body = {} # Empty will raise response = TestBasePolling.mock_send( http_request, + http_response, 'PUT', 201, {'location': LOCATION_URL}, response_body) poll = LROPoller(CLIENT, response, @@ -505,6 +524,7 @@ def no_update_allowed(url, headers=None): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'PUT', 201, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -515,6 +535,7 @@ def no_update_allowed(url, headers=None): # Test fail to poll from location header response = TestBasePolling.mock_send( http_request, + http_response, 'PUT', 201, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -522,12 +543,14 @@ def no_update_allowed(url, headers=None): TestBasePolling.mock_outputs, LROBasePolling(0)).result() - @pytest.mark.parametrize("http_request", HTTP_REQUESTS) - def test_long_running_patch(self, http_request): + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) + def test_long_running_patch(self, http_request, http_response): CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response # Test polling from location header response = TestBasePolling.mock_send( http_request, + http_response, 'PATCH', 202, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -540,6 +563,7 @@ def test_long_running_patch(self, http_request): # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'PATCH', 202, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -552,6 +576,7 @@ def test_long_running_patch(self, http_request): # Test polling from location header response = TestBasePolling.mock_send( http_request, + http_response, 'PATCH', 200, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -564,6 +589,7 @@ def test_long_running_patch(self, http_request): # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'PATCH', 200, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -576,6 +602,7 @@ def test_long_running_patch(self, http_request): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'PATCH', 202, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -586,6 +613,7 @@ def test_long_running_patch(self, http_request): # Test fail to poll from location header response = TestBasePolling.mock_send( http_request, + http_response, 'PATCH', 202, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -593,33 +621,37 @@ def test_long_running_patch(self, http_request): TestBasePolling.mock_outputs, LROBasePolling(0)).result() - @pytest.mark.parametrize("http_request", HTTP_REQUESTS) - def test_long_running_delete(self, http_request): + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) + def test_long_running_delete(self, http_request, http_response): # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'DELETE', 202, {'operation-location': ASYNC_URL}, body="" ) CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response poll = LROPoller(CLIENT, response, TestBasePolling.mock_deserialization_no_body, LROBasePolling(0)) poll.wait() assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None - @pytest.mark.parametrize("http_request", HTTP_REQUESTS) - def test_long_running_post_legacy(self, http_request): + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) + def test_long_running_post_legacy(self, http_request, http_response): # Former oooooold tests to refactor one day to something more readble # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 201, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response poll = LROPoller(CLIENT, response, TestBasePolling.mock_deserialization_no_body, LROBasePolling(0)) @@ -629,6 +661,7 @@ def test_long_running_post_legacy(self, http_request): # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -641,6 +674,7 @@ def test_long_running_post_legacy(self, http_request): # Test polling from location header response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -653,6 +687,7 @@ def test_long_running_post_legacy(self, http_request): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -663,6 +698,7 @@ def test_long_running_post_legacy(self, http_request): # Test fail to poll from location header response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -670,15 +706,17 @@ def test_long_running_post_legacy(self, http_request): TestBasePolling.mock_outputs, LROBasePolling(0)).result() - @pytest.mark.parametrize("http_request", HTTP_REQUESTS) - def test_long_running_negative(self, http_request): + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) + def test_long_running_negative(self, http_request, http_response): global LOCATION_BODY global POLLING_STATUS CLIENT.http_request_type = http_request + CLIENT.http_response_type = http_response # Test LRO PUT throws for invalid json LOCATION_BODY = '{' response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}) poll = LROPoller( @@ -693,6 +731,7 @@ def test_long_running_negative(self, http_request): LOCATION_BODY = '{\'"}' response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}) poll = LROPoller(CLIENT, response, @@ -705,6 +744,7 @@ def test_long_running_negative(self, http_request): POLLING_STATUS = 203 response = TestBasePolling.mock_send( http_request, + http_response, 'POST', 202, {'location': LOCATION_URL}) poll = LROPoller(CLIENT, response, diff --git a/sdk/core/azure-core/tests/test_basic_transport.py b/sdk/core/azure-core/tests/test_basic_transport.py index 27634b3bae58..a9f2d0b85f19 100644 --- a/sdk/core/azure-core/tests/test_basic_transport.py +++ b/sdk/core/azure-core/tests/test_basic_transport.py @@ -12,25 +12,52 @@ except ImportError: import mock -from azure.core.pipeline.transport import HttpResponse, RequestsTransport -from azure.core.pipeline.transport._base import HttpClientTransportResponse, HttpTransport, _deserialize_response, _urljoin +from azure.core.pipeline.transport import HttpResponse as PipelineTransportHttpResponse, RequestsTransport +from azure.core.pipeline.transport._base import HttpTransport, _deserialize_response, _urljoin from azure.core.pipeline.policies import HeadersPolicy from azure.core.pipeline import Pipeline from azure.core.exceptions import HttpResponseError import logging import pytest -from utils import HTTP_REQUESTS +from utils import HTTP_REQUESTS, request_and_responses_product, HTTP_CLIENT_TRANSPORT_RESPONSES, create_transport_response +from azure.core.rest._http_response_impl import HttpResponseImpl as RestHttpResponseImpl +from azure.core.pipeline._tools import is_rest -class MockResponse(HttpResponse): +class PipelineTransportMockResponse(PipelineTransportHttpResponse): def __init__(self, request, body, content_type): - super(MockResponse, self).__init__(request, None) + super(PipelineTransportMockResponse, self).__init__(request, None) self._body = body self.content_type = content_type def body(self): return self._body +class RestMockResponse(RestHttpResponseImpl): + def __init__(self, request, body, content_type): + super(RestMockResponse, self).__init__( + request=request, + internal_response=None, + content_type=content_type, + block_size=None, + status_code=200, + reason="OK", + headers={}, + stream_download_generator=None, + ) + # the impl takes in a lot more kwargs. It's not public and is a + # helper implementation shared across our azure core transport responses + self._body = body + + def body(self): + return self._body + + @property + def content(self): + return self._body + +MOCK_RESPONSES = [PipelineTransportMockResponse, RestMockResponse] + @pytest.mark.skipif(sys.version_info < (3, 6), reason="Multipart serialization not supported on 2.7 + dict order not deterministic on 3.5") @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_http_request_serialization(http_request): @@ -97,8 +124,8 @@ def test_url_join(http_request): assert _urljoin('devstoreaccount1/', 'testdir/') == 'devstoreaccount1/testdir/' -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_http_client_response(port, http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_CLIENT_TRANSPORT_RESPONSES)) +def test_http_client_response(port, http_request, http_response): # Create a core request request = http_request("GET", "http://localhost:{}".format(port)) @@ -107,7 +134,9 @@ def test_http_client_response(port, http_request): conn.request("GET", "/get") r1 = conn.getresponse() - response = HttpClientTransportResponse(request, r1) + response = create_transport_response(http_response, request, r1) + if is_rest(http_response): + response.read() # Don't assume too much in those assert, since we reach a real server assert response.internal_response is r1 @@ -621,8 +650,8 @@ def test_multipart_send_with_combination_changeset_middle(http_request): ) -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_multipart_receive(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +def test_multipart_receive(http_request, mock_response): class ResponsePolicy(object): def on_response(self, request, response): @@ -666,7 +695,7 @@ def on_response(self, request, response): "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" ) - response = MockResponse( + response = mock_response( request, body_as_str.encode('ascii'), "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -684,22 +713,22 @@ def on_response(self, request, response): assert res1.status_code == 404 assert res1.headers['x-ms-fun'] == 'true' -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_raise_for_status_bad_response(http_request): - response = MockResponse(request=None, body=None, content_type=None) +@pytest.mark.parametrize("mock_response", MOCK_RESPONSES) +def test_raise_for_status_bad_response(mock_response): + response = mock_response(request=None, body=None, content_type=None) response.status_code = 400 with pytest.raises(HttpResponseError): response.raise_for_status() -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_raise_for_status_good_response(http_request): - response = MockResponse(request=None, body=None, content_type=None) +@pytest.mark.parametrize("mock_response", MOCK_RESPONSES) +def test_raise_for_status_good_response(mock_response): + response = mock_response(request=None, body=None, content_type=None) response.status_code = 200 response.raise_for_status() -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_multipart_receive_with_one_changeset(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +def test_multipart_receive_with_one_changeset(http_request, mock_response): changeset = http_request(None, None) changeset.set_multipart_mixed( @@ -739,7 +768,7 @@ def test_multipart_receive_with_one_changeset(http_request): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -754,8 +783,8 @@ def test_multipart_receive_with_one_changeset(http_request): assert res0.status_code == 202 -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_multipart_receive_with_multiple_changesets(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +def test_multipart_receive_with_multiple_changesets(http_request, mock_response): changeset1 = http_request(None, None) changeset1.set_multipart_mixed( @@ -824,7 +853,7 @@ def test_multipart_receive_with_multiple_changesets(http_request): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -840,8 +869,8 @@ def test_multipart_receive_with_multiple_changesets(http_request): assert parts[3].status_code == 409 -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_multipart_receive_with_combination_changeset_first(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +def test_multipart_receive_with_combination_changeset_first(http_request, mock_response): changeset = http_request(None, None) changeset.set_multipart_mixed( @@ -890,7 +919,7 @@ def test_multipart_receive_with_combination_changeset_first(http_request): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -905,8 +934,8 @@ def test_multipart_receive_with_combination_changeset_first(http_request): assert parts[2].status_code == 404 -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_multipart_receive_with_combination_changeset_middle(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +def test_multipart_receive_with_combination_changeset_middle(http_request, mock_response): changeset = http_request(None, None) changeset.set_multipart_mixed(http_request("DELETE", "/container1/blob1")) @@ -956,7 +985,7 @@ def test_multipart_receive_with_combination_changeset_middle(http_request): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -971,8 +1000,8 @@ def test_multipart_receive_with_combination_changeset_middle(http_request): assert parts[2].status_code == 404 -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_multipart_receive_with_combination_changeset_last(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +def test_multipart_receive_with_combination_changeset_last(http_request, mock_response): changeset = http_request(None, None) changeset.set_multipart_mixed( @@ -1022,7 +1051,7 @@ def test_multipart_receive_with_combination_changeset_last(http_request): b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -1037,8 +1066,8 @@ def test_multipart_receive_with_combination_changeset_last(http_request): assert parts[2].status_code == 404 -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_multipart_receive_with_bom(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +def test_multipart_receive_with_bom(http_request, mock_response): req0 = http_request("DELETE", "/container0/blob0") @@ -1060,7 +1089,7 @@ def test_multipart_receive_with_bom(http_request): b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" ) - response = MockResponse( + response = mock_response( request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" @@ -1074,8 +1103,8 @@ def test_multipart_receive_with_bom(http_request): assert res0.body().startswith(b'\xef\xbb\xbf') -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_recursive_multipart_receive(http_request): +@pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) +def test_recursive_multipart_receive(http_request, mock_response): req0 = http_request("DELETE", "/container0/blob0") internal_req0 = http_request("DELETE", "/container0/blob0") req0.set_multipart_mixed(internal_req0) @@ -1106,7 +1135,7 @@ def test_recursive_multipart_receive(http_request): "--batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6--" ).format(internal_body_as_str) - response = MockResponse( + response = mock_response( request, body_as_str.encode('ascii'), "multipart/mixed; boundary=batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6" @@ -1171,6 +1200,6 @@ def test_conflict_timeout(caplog, port, http_request): def test_aiohttp_loop(): import asyncio from azure.core.pipeline.transport import AioHttpTransport - loop = asyncio.get_event_loop() + loop = asyncio._get_running_loop() with pytest.raises(ValueError): transport = AioHttpTransport(loop=loop) diff --git a/sdk/core/azure-core/tests/test_error_map.py b/sdk/core/azure-core/tests/test_error_map.py index f3f4e4c848ae..267ab1a5ee62 100644 --- a/sdk/core/azure-core/tests/test_error_map.py +++ b/sdk/core/azure-core/tests/test_error_map.py @@ -30,45 +30,42 @@ map_error, ErrorMap, ) -from azure.core.pipeline.transport import ( - HttpResponse, -) -from utils import HTTP_REQUESTS +from utils import request_and_responses_product, create_http_response, HTTP_RESPONSES -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_error_map(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_error_map(http_request, http_response): request = http_request("GET", "") - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) error_map = { 404: ResourceNotFoundError } with pytest.raises(ResourceNotFoundError): map_error(404, response, error_map) -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_error_map_no_default(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_error_map_no_default(http_request, http_response): request = http_request("GET", "") - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) error_map = ErrorMap({ 404: ResourceNotFoundError }) with pytest.raises(ResourceNotFoundError): map_error(404, response, error_map) -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_error_map_with_default(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_error_map_with_default(http_request, http_response): request = http_request("GET", "") - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) error_map = ErrorMap({ 404: ResourceNotFoundError }, default_error=ResourceExistsError) with pytest.raises(ResourceExistsError): map_error(401, response, error_map) -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_only_default(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_only_default(http_request, http_response): request = http_request("GET", "") - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) error_map = ErrorMap(default_error=ResourceExistsError) with pytest.raises(ResourceExistsError): map_error(401, response, error_map) diff --git a/sdk/core/azure-core/tests/test_exceptions.py b/sdk/core/azure-core/tests/test_exceptions.py index a5d494e559b7..02944bb04b4b 100644 --- a/sdk/core/azure-core/tests/test_exceptions.py +++ b/sdk/core/azure-core/tests/test_exceptions.py @@ -23,6 +23,7 @@ # THE SOFTWARE. # # -------------------------------------------------------------------------- +import pytest import json import requests try: @@ -34,26 +35,44 @@ # module under test from azure.core.exceptions import HttpResponseError, ODataV4Error, ODataV4Format from azure.core.pipeline.transport import RequestsTransportResponse -from azure.core.pipeline.transport._base import _HttpResponseBase - - -def _build_response(json_body): - class MockResponse(_HttpResponseBase): - def __init__(self): - super(MockResponse, self).__init__( - request=None, - internal_response = None, - ) - self.status_code = 400 - self.reason = "Bad Request" - self.content_type = "application/json" - self._body = json_body - - def body(self): - return self._body - - return MockResponse() - +from azure.core.pipeline.transport._base import _HttpResponseBase as PipelineTransportHttpResponseBase +from azure.core.rest._http_response_impl import _HttpResponseBaseImpl as RestHttpResponseBase + +class PipelineTransportMockResponse(PipelineTransportHttpResponseBase): + def __init__(self, json_body): + super(PipelineTransportMockResponse, self).__init__( + request=None, + internal_response = None, + ) + self.status_code = 400 + self.reason = "Bad Request" + self.content_type = "application/json" + self._body = json_body + + def body(self): + return self._body + +class RestMockResponse(RestHttpResponseBase): + def __init__(self, json_body): + super(RestMockResponse, self).__init__( + request=None, + internal_response=None, + status_code=400, + reason="Bad Request", + content_type="application/json", + headers={}, + stream_download_generator=None, + ) + self._body = json_body + + def body(self): + return self._body + + @property + def content(self): + return self._body + +MOCK_RESPONSES = [PipelineTransportMockResponse, RestMockResponse] class FakeErrorOne(object): @@ -105,7 +124,8 @@ def test_error_continuation_token(self): assert error.status_code is None assert error.continuation_token == 'foo' - def test_deserialized_httpresponse_error_code(self): + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) + def test_deserialized_httpresponse_error_code(self, mock_response): """This is backward compat support of autorest azure-core (KV 4.0.0, Storage 12.0.0). Do NOT adapt this test unless you know what you're doing. @@ -116,7 +136,7 @@ def test_deserialized_httpresponse_error_code(self): "message": "A fake error", } } - response = _build_response(json.dumps(message).encode("utf-8")) + response = mock_response(json.dumps(message).encode("utf-8")) error = FakeHttpResponse(response, FakeErrorOne()) assert "(FakeErrorOne) A fake error" in error.message assert "(FakeErrorOne) A fake error" in str(error.error) @@ -133,7 +153,8 @@ def test_deserialized_httpresponse_error_code(self): assert error.error.error.message == "A fake error" - def test_deserialized_httpresponse_error_message(self): + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) + def test_deserialized_httpresponse_error_message(self, mock_response): """This is backward compat support for weird responses, adn even if it's likely just the autorest testserver, should be fine parsing. @@ -143,7 +164,7 @@ def test_deserialized_httpresponse_error_message(self): "code": "FakeErrorTwo", "message": "A different fake error", } - response = _build_response(json.dumps(message).encode("utf-8")) + response = mock_response(json.dumps(message).encode("utf-8")) error = FakeHttpResponse(response, FakeErrorTwo()) assert "(FakeErrorTwo) A different fake error" in error.message assert "(FakeErrorTwo) A different fake error" in str(error.error) @@ -155,7 +176,8 @@ def test_deserialized_httpresponse_error_message(self): assert isinstance(error.model, FakeErrorTwo) assert isinstance(error.error, ODataV4Format) - def test_httpresponse_error_with_response(self, port): + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) + def test_httpresponse_error_with_response(self, port, mock_response): response = requests.get("http://localhost:{}/basic/string".format(port)) http_response = RequestsTransportResponse(None, response) @@ -166,7 +188,8 @@ def test_httpresponse_error_with_response(self, port): assert isinstance(error.status_code, int) assert error.error is None - def test_odata_v4_exception(self): + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) + def test_odata_v4_exception(self, mock_response): message = { "error": { "code": "501", @@ -183,7 +206,7 @@ def test_odata_v4_exception(self): } } } - exp = ODataV4Error(_build_response(json.dumps(message).encode("utf-8"))) + exp = ODataV4Error(mock_response(json.dumps(message).encode("utf-8"))) assert exp.code == "501" assert exp.message == "Unsupported functionality" @@ -194,14 +217,15 @@ def test_odata_v4_exception(self): assert "context" in exp.innererror message = {} - exp = ODataV4Error(_build_response(json.dumps(message).encode("utf-8"))) + exp = ODataV4Error(mock_response(json.dumps(message).encode("utf-8"))) assert exp.message == "Operation returned an invalid status 'Bad Request'" - exp = ODataV4Error(_build_response(b"")) + exp = ODataV4Error(mock_response(b"")) assert exp.message == "Operation returned an invalid status 'Bad Request'" assert str(exp) == "Operation returned an invalid status 'Bad Request'" - def test_odata_v4_minimal(self): + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) + def test_odata_v4_minimal(self, mock_response): """Minimal valid OData v4 is code/message and nothing else. """ message = { @@ -210,14 +234,15 @@ def test_odata_v4_minimal(self): "message": "Unsupported functionality", } } - exp = ODataV4Error(_build_response(json.dumps(message).encode("utf-8"))) + exp = ODataV4Error(mock_response(json.dumps(message).encode("utf-8"))) assert exp.code == "501" assert exp.message == "Unsupported functionality" assert exp.target is None assert exp.details == [] assert exp.innererror == {} - def test_broken_odata_details(self): + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) + def test_broken_odata_details(self, mock_response): """Do not block creating a nice exception if "details" only is broken """ message = { @@ -244,10 +269,11 @@ def test_broken_odata_details(self): "innererror": None, } } - exp = HttpResponseError(response=_build_response(json.dumps(message).encode("utf-8"))) + exp = HttpResponseError(response=mock_response(json.dumps(message).encode("utf-8"))) assert exp.error.code == "Conflict" - def test_null_odata_details(self): + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) + def test_null_odata_details(self, mock_response): message = { "error": { "code": "501", @@ -257,5 +283,5 @@ def test_null_odata_details(self): "innererror": None, } } - exp = HttpResponseError(response=_build_response(json.dumps(message).encode("utf-8"))) + exp = HttpResponseError(response=mock_response(json.dumps(message).encode("utf-8"))) assert exp.error.code == "501" \ No newline at end of file diff --git a/sdk/core/azure-core/tests/test_http_logging_policy.py b/sdk/core/azure-core/tests/test_http_logging_policy.py index 17f2a61e6953..05b1f9491ab0 100644 --- a/sdk/core/azure-core/tests/test_http_logging_policy.py +++ b/sdk/core/azure-core/tests/test_http_logging_policy.py @@ -15,17 +15,14 @@ PipelineRequest, PipelineContext ) -from azure.core.pipeline.transport import ( - HttpResponse, -) from azure.core.pipeline.policies import ( HttpLoggingPolicy, ) -from utils import HTTP_REQUESTS - +from utils import HTTP_RESPONSES, create_http_response, request_and_responses_product +from azure.core.pipeline._tools import is_rest -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_http_logger(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_http_logger(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -44,7 +41,7 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) universal_request = http_request('GET', 'http://localhost/') - http_response = HttpResponse(universal_request, None) + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) @@ -147,8 +144,8 @@ def emit(self, record): -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_http_logger_operation_level(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_http_logger_operation_level(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -168,7 +165,7 @@ def emit(self, record): kwargs={'logger': logger} universal_request = http_request('GET', 'http://localhost/') - http_response = HttpResponse(universal_request, None) + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None, **kwargs)) @@ -225,8 +222,8 @@ def emit(self, record): mock_handler.reset() -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_http_logger_with_body(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_http_logger_with_body(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -246,7 +243,7 @@ def emit(self, record): universal_request = http_request('GET', 'http://localhost/') universal_request.body = "testbody" - http_response = HttpResponse(universal_request, None) + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) @@ -268,8 +265,8 @@ def emit(self, record): mock_handler.reset() -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_http_logger_with_generator_body(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_http_logger_with_generator_body(http_request, http_response): class MockHandler(logging.Handler): def __init__(self): @@ -291,7 +288,7 @@ def emit(self, record): mock = Mock() mock.__class__ = types.GeneratorType universal_request.body = mock - http_response = HttpResponse(universal_request, None) + http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) diff --git a/sdk/core/azure-core/tests/test_pipeline.py b/sdk/core/azure-core/tests/test_pipeline.py index fb5a063e4f7b..e8050d3ffb97 100644 --- a/sdk/core/azure-core/tests/test_pipeline.py +++ b/sdk/core/azure-core/tests/test_pipeline.py @@ -379,6 +379,8 @@ def test_basic_requests(port, http_request): ] with Pipeline(RequestsTransport(), policies=policies) as pipeline: response = pipeline.run(request) + if is_rest(request): + assert is_rest(response.http_response) assert pipeline._transport.session is None assert isinstance(response.http_response.status_code, int) @@ -393,6 +395,8 @@ def test_basic_options_requests(port, http_request): ] with Pipeline(RequestsTransport(), policies=policies) as pipeline: response = pipeline.run(request) + if is_rest(request): + assert is_rest(response.http_response) assert pipeline._transport.session is None assert isinstance(response.http_response.status_code, int) @@ -409,6 +413,8 @@ def test_basic_requests_separate_session(port, http_request): transport = RequestsTransport(session=session, session_owner=False) with Pipeline(transport, policies=policies) as pipeline: response = pipeline.run(request) + if is_rest(request): + assert is_rest(response.http_response) assert transport.session assert isinstance(response.http_response.status_code, int) diff --git a/sdk/core/azure-core/tests/test_requests_universal.py b/sdk/core/azure-core/tests/test_requests_universal.py index 891a364cfd61..80dbf66299f8 100644 --- a/sdk/core/azure-core/tests/test_requests_universal.py +++ b/sdk/core/azure-core/tests/test_requests_universal.py @@ -26,8 +26,9 @@ import concurrent.futures import requests.utils import pytest -from azure.core.pipeline.transport import RequestsTransport, RequestsTransportResponse -from utils import HTTP_REQUESTS +from azure.core.pipeline.transport import RequestsTransport +from utils import HTTP_REQUESTS, REQUESTS_TRANSPORT_RESPONSES, create_transport_response +from azure.core.pipeline._tools import is_rest def test_threading_basic_requests(): @@ -53,7 +54,7 @@ def test_requests_auto_headers(port, http_request): auto_headers = response.internal_response.request.headers assert 'Content-Type' not in auto_headers -def _create_requests_response(body_bytes, headers=None): +def _create_requests_response(http_response, body_bytes, headers=None): # https://github.com/psf/requests/blob/67a7b2e8336951d527e223429672354989384197/requests/adapters.py#L255 req_response = requests.Response() req_response._content = body_bytes @@ -65,27 +66,34 @@ def _create_requests_response(body_bytes, headers=None): req_response.headers.update(headers) req_response.encoding = requests.utils.get_encoding_from_headers(req_response.headers) - response = RequestsTransportResponse( + response = create_transport_response( + http_response, None, # Don't need a request here req_response ) return response - -def test_requests_response_text(): +@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) +def test_requests_response_text(http_response): for encoding in ["utf-8", "utf-8-sig", None]: res = _create_requests_response( + http_response, b'\xef\xbb\xbf56', {'Content-Type': 'text/plain'} ) + if is_rest(http_response): + res.read() assert res.text(encoding) == '56', "Encoding {} didn't work".format(encoding) -def test_repr(): +@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) +def test_repr(http_response): res = _create_requests_response( + http_response, b'\xef\xbb\xbf56', {'Content-Type': 'text/plain'} ) - assert repr(res) == "" + class_name = "HttpResponse" if is_rest(http_response) else "RequestsTransportResponse" + assert repr(res) == "<{}: 200 OK, Content-Type: text/plain>".format(class_name) diff --git a/sdk/core/azure-core/tests/test_rest_http_request.py b/sdk/core/azure-core/tests/test_rest_http_request.py index 0f16061badbd..1072b7b49e42 100644 --- a/sdk/core/azure-core/tests/test_rest_http_request.py +++ b/sdk/core/azure-core/tests/test_rest_http_request.py @@ -20,7 +20,7 @@ from azure.core.pipeline.policies import ( CustomHookPolicy, UserAgentPolicy, SansIOHTTPPolicy, RetryPolicy ) -from utils import is_rest +from azure.core.pipeline._tools import is_rest from rest_client import TestRestClient from azure.core import PipelineClient diff --git a/sdk/core/azure-core/tests/test_rest_http_response.py b/sdk/core/azure-core/tests/test_rest_http_response.py index db411943ec1c..f5e706cafc3e 100644 --- a/sdk/core/azure-core/tests/test_rest_http_response.py +++ b/sdk/core/azure-core/tests/test_rest_http_response.py @@ -155,15 +155,6 @@ def test_response_no_charset_with_iso_8859_1_content(send_request): assert response.text() == u"Accented: �sterreich" assert response.encoding is None -def test_response_set_explicit_encoding(send_request): - # Deliberately incorrect charset - response = send_request( - request=HttpRequest("GET", "/encoding/latin-1-with-utf-8"), - ) - assert response.headers["Content-Type"] == "text/plain; charset=utf-8" - response.encoding = "latin-1" - assert response.text() == u"Latin 1: ÿ" - assert response.encoding == "latin-1" def test_json(send_request): response = send_request( @@ -337,4 +328,5 @@ def test_readonly(send_request): response = send_request(HttpRequest("GET", "/health")) assert isinstance(response, RestRequestsTransportResponse) - readonly_checks(response) + from azure.core.pipeline.transport import RequestsTransportResponse + readonly_checks(response, old_response_class=RequestsTransportResponse) diff --git a/sdk/core/azure-core/tests/test_rest_backcompat.py b/sdk/core/azure-core/tests/test_rest_request_backcompat.py similarity index 100% rename from sdk/core/azure-core/tests/test_rest_backcompat.py rename to sdk/core/azure-core/tests/test_rest_request_backcompat.py diff --git a/sdk/core/azure-core/tests/test_rest_response_backcompat.py b/sdk/core/azure-core/tests/test_rest_response_backcompat.py new file mode 100644 index 000000000000..b2179abf4f91 --- /dev/null +++ b/sdk/core/azure-core/tests/test_rest_response_backcompat.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import sys +from rest_client import TestRestClient +import pytest +from azure.core.pipeline.transport import HttpRequest as PipelineTransportHttpRequest +from azure.core.rest import HttpRequest as RestHttpRequest +from azure.core.pipeline import Pipeline +from azure.core.pipeline.transport import RequestsTransport + +@pytest.fixture +def old_request(port): + return PipelineTransportHttpRequest("GET", "http://localhost:{}/streams/basic".format(port)) + +@pytest.fixture +def old_response(old_request): + return RequestsTransport().send(old_request) + +@pytest.fixture +def new_request(port): + return RestHttpRequest("GET", "http://localhost:{}/streams/basic".format(port)) + +@pytest.fixture +def new_response(new_request): + return RequestsTransport().send(new_request) + +def test_response_attr_parity(old_response, new_response): + for attr in dir(old_response): + if not attr[0] == "_": + # if not a private attr, we want partiy + assert hasattr(new_response, attr) + +def test_response_set_attrs(old_response, new_response): + for attr in dir(old_response): + if attr[0] == "_": + continue + try: + # if we can set it on the old request, we want to + # be able to set it on the new + setattr(old_response, attr, "foo") + except: + pass + else: + setattr(new_response, attr, "foo") + assert getattr(old_response, attr) == getattr(new_response, attr) == "foo" + +def test_response_block_size(old_response, new_response): + assert old_response.block_size == new_response.block_size == 4096 + old_response.block_size = 500 + new_response.block_size = 500 + assert old_response.block_size == new_response.block_size == 500 + +def test_response_body(old_response, new_response): + assert old_response.body() == new_response.body() == b"Hello, world!" + +def test_response_internal_response(old_response, new_response, port): + assert old_response.internal_response.url == new_response.internal_response.url == "http://localhost:{}/streams/basic".format(port) + old_response.internal_response = "foo" + new_response.internal_response = "foo" + assert old_response.internal_response == new_response.internal_response == "foo" + +def test_response_stream_download(old_request, new_request): + transport = RequestsTransport() + pipeline = Pipeline(transport) + + old_response = transport.send(old_request, stream=True) + old_string = b"".join(old_response.stream_download(pipeline=pipeline)) + + new_response = transport.send(new_request, stream=True) + new_string = b"".join(new_response.stream_download(pipeline)) + assert old_string == new_string == b"Hello, world!" + +def test_response_request(old_response, new_response, port): + assert old_response.request.url == new_response.request.url == "http://localhost:{}/streams/basic".format(port) + old_response.request = "foo" + new_response.request = "foo" + assert old_response.request == new_response.request == "foo" + +def test_response_status_code(old_response, new_response): + assert old_response.status_code == new_response.status_code == 200 + old_response.status_code = 202 + new_response.status_code = 202 + assert old_response.status_code == new_response.status_code == 202 + +def test_response_headers(old_response, new_response): + assert set(old_response.headers.keys()) == set(new_response.headers.keys()) == set(["Content-Type", "Connection", "Server", "Date"]) + old_response.headers = {"Hello": "world!"} + new_response.headers = {"Hello": "world!"} + assert old_response.headers == new_response.headers == {"Hello": "world!"} + +def test_response_reason(old_response, new_response): + assert old_response.reason == new_response.reason == "OK" + old_response.reason = "Not OK" + new_response.reason = "Not OK" + assert old_response.reason == new_response.reason == "Not OK" + +def test_response_content_type(old_response, new_response): + assert old_response.content_type == new_response.content_type == "text/html; charset=utf-8" + old_response.content_type = "application/json" + new_response.content_type = "application/json" + assert old_response.content_type == new_response.content_type == "application/json" + +def _create_multiapart_request(http_request_class): + class ResponsePolicy(object): + + def on_request(self, *args): + return + + def on_response(self, request, response): + response.http_response.headers['x-ms-fun'] = 'true' + + req0 = http_request_class("DELETE", "/container0/blob0") + req1 = http_request_class("DELETE", "/container1/blob1") + request = http_request_class("POST", "/multipart/request") + request.set_multipart_mixed(req0, req1, policies=[ResponsePolicy()]) + return request + +def _test_parts(response): + # hack the content type + parts = response.parts() + assert len(parts) == 2 + + parts0 = parts[0] + assert parts0.status_code == 202 + assert parts0.headers['x-ms-fun'] == 'true' + + parts1 = parts[1] + assert parts1.status_code == 404 + assert parts1.headers['x-ms-fun'] == 'true' + +@pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") +def test_response_parts(port): + old_request = _create_multiapart_request(PipelineTransportHttpRequest) + new_request = _create_multiapart_request(RestHttpRequest) + + old_response = TestRestClient(port).send_request(old_request, stream=True) + new_response = TestRestClient(port).send_request(new_request, stream=True) + _test_parts(old_response) + _test_parts(new_response) diff --git a/sdk/core/azure-core/tests/test_retry_policy.py b/sdk/core/azure-core/tests/test_retry_policy.py index 8e989aed487c..994841b52d22 100644 --- a/sdk/core/azure-core/tests/test_retry_policy.py +++ b/sdk/core/azure-core/tests/test_retry_policy.py @@ -23,7 +23,6 @@ ) from azure.core.pipeline import Pipeline, PipelineResponse from azure.core.pipeline.transport import ( - HttpResponse, HttpTransport, ) import tempfile @@ -34,7 +33,7 @@ from unittest.mock import Mock except ImportError: from mock import Mock -from utils import HTTP_REQUESTS +from utils import HTTP_REQUESTS, request_and_responses_product, HTTP_RESPONSES, create_http_response def test_retry_code_class_variables(): @@ -63,11 +62,11 @@ def test_retry_types(): backoff_time = retry_policy.get_backoff_time(settings) assert backoff_time == 4 -@pytest.mark.parametrize("retry_after_input,http_request", product(['0', '800', '1000', '1200'], HTTP_REQUESTS)) -def test_retry_after(retry_after_input, http_request): +@pytest.mark.parametrize("retry_after_input,http_request,http_response", product(['0', '800', '1000', '1200'], HTTP_REQUESTS, HTTP_RESPONSES)) +def test_retry_after(retry_after_input, http_request, http_response): retry_policy = RetryPolicy() request = http_request("GET", "http://localhost") - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.headers["retry-after-ms"] = retry_after_input pipeline_response = PipelineResponse(request, response, None) retry_after = retry_policy.get_retry_after(pipeline_response) @@ -81,11 +80,11 @@ def test_retry_after(retry_after_input, http_request): retry_after = retry_policy.get_retry_after(pipeline_response) assert retry_after == float(retry_after_input) -@pytest.mark.parametrize("retry_after_input,http_request", product(['0', '800', '1000', '1200'], HTTP_REQUESTS)) -def test_x_ms_retry_after(retry_after_input, http_request): +@pytest.mark.parametrize("retry_after_input,http_request,http_response", product(['0', '800', '1000', '1200'], HTTP_REQUESTS, HTTP_RESPONSES)) +def test_x_ms_retry_after(retry_after_input, http_request, http_response): retry_policy = RetryPolicy() request = http_request("GET", "http://localhost") - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.headers["x-ms-retry-after-ms"] = retry_after_input pipeline_response = PipelineResponse(request, response, None) retry_after = retry_policy.get_retry_after(pipeline_response) @@ -99,8 +98,8 @@ def test_x_ms_retry_after(retry_after_input, http_request): retry_after = retry_policy.get_retry_after(pipeline_response) assert retry_after == float(retry_after_input) -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_retry_on_429(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_retry_on_429(http_request, http_response): class MockTransport(HttpTransport): def __init__(self): self._count = 0 @@ -113,7 +112,7 @@ def open(self): def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineResponse self._count += 1 - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 429 return response @@ -124,8 +123,8 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe pipeline.run(http_request) assert transport._count == 2 -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_no_retry_on_201(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_no_retry_on_201(http_request, http_response): class MockTransport(HttpTransport): def __init__(self): self._count = 0 @@ -138,7 +137,7 @@ def open(self): def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineResponse self._count += 1 - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 201 headers = {"Retry-After": "1"} response.headers = headers @@ -151,8 +150,8 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe pipeline.run(http_request) assert transport._count == 1 -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_retry_seekable_stream(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_retry_seekable_stream(http_request, http_response): class MockTransport(HttpTransport): def __init__(self): self._first = True @@ -170,7 +169,7 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe raise AzureError('fail on first') position = request.body.tell() assert position == 0 - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 400 return response @@ -181,8 +180,8 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe pipeline = Pipeline(MockTransport(), [http_retry]) pipeline.run(http_request) -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_retry_seekable_file(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_retry_seekable_file(http_request, http_response): class MockTransport(HttpTransport): def __init__(self): self._first = True @@ -206,7 +205,7 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe if name and body and hasattr(body, 'read'): position = body.tell() assert not position - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 400 return response @@ -246,14 +245,14 @@ def send(request, **kwargs): with pytest.raises(ServiceResponseTimeoutError): response = pipeline.run(http_request("GET", "http://localhost/")) -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_timeout_defaults(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_timeout_defaults(http_request, http_response): """When "timeout" is not set, the policy should not override the transport's timeout configuration""" def send(request, **kwargs): for arg in ("connection_timeout", "read_timeout"): assert arg not in kwargs, "policy should defer to transport configuration when not given a timeout" - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 200 return response diff --git a/sdk/core/azure-core/tests/test_stream_generator.py b/sdk/core/azure-core/tests/test_stream_generator.py index 4c9f5190e9c6..d572a154d936 100644 --- a/sdk/core/azure-core/tests/test_stream_generator.py +++ b/sdk/core/azure-core/tests/test_stream_generator.py @@ -4,10 +4,8 @@ # ------------------------------------ import requests from azure.core.pipeline.transport import ( - HttpResponse, HttpTransport, RequestsTransport, - RequestsTransportResponse, ) from azure.core.pipeline import Pipeline, PipelineResponse from azure.core.pipeline.transport._requests_basic import StreamDownloadGenerator @@ -16,10 +14,10 @@ except ImportError: import mock import pytest -from utils import HTTP_REQUESTS +from utils import HTTP_RESPONSES, REQUESTS_TRANSPORT_RESPONSES, create_http_response, create_transport_response, request_and_responses_product -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_connection_error_response(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_connection_error_response(http_request, http_response): class MockTransport(HttpTransport): def __init__(self): self._count = 0 @@ -33,7 +31,7 @@ def open(self): def send(self, request, **kwargs): request = http_request('GET', 'http://localhost/') - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.status_code = 200 return response @@ -61,14 +59,15 @@ def close(self): http_request = http_request('GET', 'http://localhost/') pipeline = Pipeline(MockTransport()) - http_response = HttpResponse(http_request, None) + http_response = create_http_response(http_response, http_request, None) http_response.internal_response = MockInternalResponse() stream = StreamDownloadGenerator(pipeline, http_response, decompress=False) with mock.patch('time.sleep', return_value=None): with pytest.raises(requests.exceptions.ConnectionError): stream.__next__() -def test_response_streaming_error_behavior(): +@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) +def test_response_streaming_error_behavior(http_response): # Test to reproduce https://github.com/Azure/azure-sdk-for-python/issues/16723 block_size = 103 total_response_size = 500 @@ -105,7 +104,8 @@ def close(self): s = FakeStreamWithConnectionError() req_response.raw = FakeStreamWithConnectionError() - response = RequestsTransportResponse( + response = create_transport_response( + http_response, req_request, req_response, block_size, diff --git a/sdk/core/azure-core/tests/test_tracing_policy.py b/sdk/core/azure-core/tests/test_tracing_policy.py index ef46a53197e3..6d3524a3c639 100644 --- a/sdk/core/azure-core/tests/test_tracing_policy.py +++ b/sdk/core/azure-core/tests/test_tracing_policy.py @@ -7,12 +7,11 @@ from azure.core.pipeline import PipelineResponse, PipelineRequest, PipelineContext from azure.core.pipeline.policies import DistributedTracingPolicy, UserAgentPolicy -from azure.core.pipeline.transport import HttpResponse from azure.core.settings import settings from tracing_common import FakeSpan import time import pytest -from utils import HTTP_REQUESTS +from utils import HTTP_RESPONSES, create_http_response, request_and_responses_product try: from unittest import mock @@ -20,8 +19,8 @@ import mock -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_distributed_tracing_policy_solo(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_distributed_tracing_policy_solo(http_request, http_response): """Test policy with no other policy and happy path""" settings.tracing_implementation.set_value(FakeSpan) with FakeSpan(name="parent") as root_span: @@ -33,7 +32,7 @@ def test_distributed_tracing_policy_solo(http_request): pipeline_request = PipelineRequest(request, PipelineContext(None)) policy.on_request(pipeline_request) - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.headers = request.headers response.status_code = 202 response.headers["x-ms-request-id"] = "some request id" @@ -71,8 +70,8 @@ def test_distributed_tracing_policy_solo(http_request): assert network_span.attributes.get("http.status_code") == 504 -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_distributed_tracing_policy_attributes(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_distributed_tracing_policy_attributes(http_request, http_response): """Test policy with no other policy and happy path""" settings.tracing_implementation.set_value(FakeSpan) with FakeSpan(name="parent") as root_span: @@ -85,7 +84,7 @@ def test_distributed_tracing_policy_attributes(http_request): pipeline_request = PipelineRequest(request, PipelineContext(None)) policy.on_request(pipeline_request) - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.headers = request.headers response.status_code = 202 @@ -96,8 +95,8 @@ def test_distributed_tracing_policy_attributes(http_request): assert network_span.attributes.get("myattr") == "myvalue" -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_distributed_tracing_policy_badurl(caplog, http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_distributed_tracing_policy_badurl(caplog, http_request, http_response): """Test policy with a bad url that will throw, and be sure policy ignores it""" settings.tracing_implementation.set_value(FakeSpan) with FakeSpan(name="parent") as root_span: @@ -111,7 +110,7 @@ def test_distributed_tracing_policy_badurl(caplog, http_request): policy.on_request(pipeline_request) assert "Unable to start network span" in caplog.text - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.headers = request.headers response.status_code = 202 response.headers["x-ms-request-id"] = "some request id" @@ -130,8 +129,8 @@ def test_distributed_tracing_policy_badurl(caplog, http_request): assert len(root_span.children) == 0 -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_distributed_tracing_policy_with_user_agent(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_distributed_tracing_policy_with_user_agent(http_request, http_response): """Test policy working with user agent.""" settings.tracing_implementation.set_value(FakeSpan) with mock.patch.dict('os.environ', {"AZURE_HTTP_USER_AGENT": "mytools"}): @@ -147,7 +146,7 @@ def test_distributed_tracing_policy_with_user_agent(http_request): user_agent.on_request(pipeline_request) policy.on_request(pipeline_request) - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.headers = request.headers response.status_code = 202 response.headers["x-ms-request-id"] = "some request id" @@ -189,8 +188,8 @@ def test_distributed_tracing_policy_with_user_agent(http_request): assert network_span.status == 'Transport trouble' -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_span_namer(http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_span_namer(http_request, http_response): settings.tracing_implementation.set_value(FakeSpan) with FakeSpan(name="parent") as root_span: @@ -205,7 +204,7 @@ def fixed_namer(http_request): policy.on_request(pipeline_request) - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.headers = request.headers response.status_code = 202 @@ -219,7 +218,7 @@ def operation_namer(http_request): policy.on_request(pipeline_request) - response = HttpResponse(request, None) + response = create_http_response(http_response, request, None) response.headers = request.headers response.status_code = 202 diff --git a/sdk/core/azure-core/tests/test_universal_pipeline.py b/sdk/core/azure-core/tests/test_universal_pipeline.py index 9d532608228b..b13e5b019708 100644 --- a/sdk/core/azure-core/tests/test_universal_pipeline.py +++ b/sdk/core/azure-core/tests/test_universal_pipeline.py @@ -41,10 +41,6 @@ PipelineRequest, PipelineContext ) -from azure.core.pipeline.transport import ( - HttpResponse, - RequestsTransportResponse, -) from azure.core.pipeline.policies import ( NetworkTraceLoggingPolicy, @@ -53,7 +49,8 @@ RetryPolicy, HTTPPolicy, ) -from utils import HTTP_REQUESTS, create_http_request +from utils import HTTP_REQUESTS, create_http_request, HTTP_RESPONSES, REQUESTS_TRANSPORT_RESPONSES, create_http_response, create_transport_response, request_and_responses_product +from azure.core.pipeline._tools import is_rest def test_pipeline_context(): kwargs={ @@ -115,12 +112,12 @@ def __deepcopy__(self, memodict={}): assert request_history.http_request.method == request.method @mock.patch('azure.core.pipeline.policies._universal._LOGGER') -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_no_log(mock_http_logger, http_request): +@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) +def test_no_log(mock_http_logger, http_request, http_response): universal_request = http_request('GET', 'http://localhost/') request = PipelineRequest(universal_request, PipelineContext(None)) http_logger = NetworkTraceLoggingPolicy() - response = PipelineResponse(request, HttpResponse(universal_request, None), request.context) + response = PipelineResponse(request, create_http_response(http_response, universal_request, None), request.context) # By default, no log handler for HTTP http_logger.on_request(request) @@ -192,22 +189,45 @@ def send(*args): with pytest.raises(AzureError): pipeline.run(http_request('GET', url='https://foo.bar')) -@pytest.mark.parametrize("http_request", HTTP_REQUESTS) -def test_raw_deserializer(http_request): +@pytest.mark.parametrize("http_request,http_response,requests_transport_response", request_and_responses_product(HTTP_RESPONSES, REQUESTS_TRANSPORT_RESPONSES)) +def test_raw_deserializer(http_request, http_response, requests_transport_response): raw_deserializer = ContentDecodePolicy() context = PipelineContext(None, stream=False) universal_request = http_request('GET', 'http://localhost/') request = PipelineRequest(universal_request, context) def build_response(body, content_type=None): - class MockResponse(HttpResponse): - def __init__(self, body, content_type): - super(MockResponse, self).__init__(None, None) - self._body = body - self.content_type = content_type - - def body(self): - return self._body + if is_rest(http_response): + class MockResponse(http_response): + def __init__(self, body, content_type): + super(MockResponse, self).__init__( + request=None, + internal_response=None, + status_code=400, + reason="Bad Request", + content_type="application/json", + headers={}, + stream_download_generator=None, + ) + self._body = body + self.content_type = content_type + + def body(self): + return self._body + + def read(self): + self._content = self._body + return self.content + + else: + class MockResponse(http_response): + def __init__(self, body, content_type): + super(MockResponse, self).__init__(None, None) + self._body = body + self.content_type = content_type + + def body(self): + return self._body return PipelineResponse(request, MockResponse(body, content_type), context) @@ -292,7 +312,7 @@ def body(self): req_response.headers["content-type"] = "application/json" req_response._content = b'{"success": true}' req_response._content_consumed = True - response = PipelineResponse(None, RequestsTransportResponse(None, req_response), PipelineContext(None, stream=False)) + response = PipelineResponse(None, create_transport_response(requests_transport_response, None, req_response), PipelineContext(None, stream=False)) raw_deserializer.on_response(request, response) result = response.context["deserialized_data"] diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/multipart.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/multipart.py index 236496673a2f..9be44121d2d4 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/multipart.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/multipart.py @@ -86,3 +86,34 @@ def non_seekable_filelike(): else: return Response(status=400) return Response(status=200) + +@multipart_api.route('/request', methods=["POST"]) +def multipart_request(): + body_as_str = ( + "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + "Content-Type: application/http\r\n" + "Content-ID: 0\r\n" + "\r\n" + "HTTP/1.1 202 Accepted\r\n" + "x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + "x-ms-version: 2018-11-09\r\n" + "\r\n" + "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + "Content-Type: application/http\r\n" + "Content-ID: 2\r\n" + "\r\n" + "HTTP/1.1 404 The specified blob does not exist.\r\n" + "x-ms-error-code: BlobNotFound\r\n" + "x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e2852\r\n" + "x-ms-version: 2018-11-09\r\n" + "Content-Length: 216\r\n" + "Content-Type: application/xml\r\n" + "\r\n" + '\r\n' + "BlobNotFoundThe specified blob does not exist.\r\n" + "RequestId:778fdc83-801e-0000-62ff-0334671e2852\r\n" + "Time:2018-06-14T16:46:54.6040685Z\r\n" + "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" + ) + return Response(body_as_str.encode('ascii'), content_type="multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed") + diff --git a/sdk/core/azure-core/tests/utils.py b/sdk/core/azure-core/tests/utils.py index f025a124d745..a70c5db98243 100644 --- a/sdk/core/azure-core/tests/utils.py +++ b/sdk/core/azure-core/tests/utils.py @@ -4,17 +4,62 @@ # license information. # ------------------------------------------------------------------------- import pytest +import types ############################## LISTS USED TO PARAMETERIZE TESTS ############################## from azure.core.rest import HttpRequest as RestHttpRequest from azure.core.pipeline.transport import HttpRequest as PipelineTransportHttpRequest - +from azure.core.pipeline._tools import is_rest HTTP_REQUESTS = [PipelineTransportHttpRequest, RestHttpRequest] +REQUESTS_TRANSPORT_RESPONSES = [] + +from azure.core.pipeline.transport import HttpResponse as PipelineTransportHttpResponse +from azure.core.rest._http_response_impl import HttpResponseImpl as RestHttpResponse +HTTP_RESPONSES = [PipelineTransportHttpResponse, RestHttpResponse] + +ASYNC_HTTP_RESPONSES = [] + +try: + from azure.core.pipeline.transport import AsyncHttpResponse as PipelineTransportAsyncHttpResponse + from azure.core.rest._http_response_impl_async import AsyncHttpResponseImpl as RestAsyncHttpResponse + ASYNC_HTTP_RESPONSES = [PipelineTransportAsyncHttpResponse, RestAsyncHttpResponse] +except (ImportError, SyntaxError): + pass + +try: + from azure.core.pipeline.transport import RequestsTransportResponse as PipelineTransportRequestsTransportResponse + from azure.core.rest._requests_basic import RestRequestsTransportResponse + REQUESTS_TRANSPORT_RESPONSES = [PipelineTransportRequestsTransportResponse, RestRequestsTransportResponse] +except ImportError: + pass + +from azure.core.pipeline.transport._base import HttpClientTransportResponse as PipelineTransportHttpClientTransportResponse +from azure.core.rest._http_response_impl import RestHttpClientTransportResponse +HTTP_CLIENT_TRANSPORT_RESPONSES = [PipelineTransportHttpClientTransportResponse, RestHttpClientTransportResponse] + +ASYNCIO_REQUESTS_TRANSPORT_RESPONSES = [] +try: + from azure.core.pipeline.transport import AsyncioRequestsTransportResponse as PipelineTransportAsyncioRequestsTransportResponse + from azure.core.rest._requests_asyncio import RestAsyncioRequestsTransportResponse + ASYNCIO_REQUESTS_TRANSPORT_RESPONSES = [PipelineTransportAsyncioRequestsTransportResponse, RestAsyncioRequestsTransportResponse] +except (ImportError, SyntaxError): + pass + +AIOHTTP_TRANSPORT_RESPONSES = [] + +try: + from azure.core.pipeline.transport import AioHttpTransportResponse as PipelineTransportAioHttpTransportResponse + from azure.core.rest._aiohttp import RestAioHttpTransportResponse + AIOHTTP_TRANSPORT_RESPONSES = [PipelineTransportAioHttpTransportResponse, RestAioHttpTransportResponse] +except (ImportError, SyntaxError): + pass ############################## HELPER FUNCTIONS ############################## -def is_rest(http_request): - return hasattr(http_request, "content") +def request_and_responses_product(*args): + pipeline_transport = tuple([PipelineTransportHttpRequest]) + tuple(arg[0] for arg in args) + rest = tuple([RestHttpRequest]) + tuple(arg[1] for arg in args) + return [pipeline_transport, rest] def create_http_request(http_request, *args, **kwargs): if hasattr(http_request, "content"): @@ -42,26 +87,47 @@ def create_http_request(http_request, *args, **kwargs): ) return http_request(*args, **kwargs) -def readonly_checks(response): - assert isinstance(response.request, RestHttpRequest) - with pytest.raises(AttributeError): - response.request = None +def create_transport_response(http_response, *args, **kwargs): + # this creates transport-specific responses, + # like requests responses / aiohttp responses + if is_rest(http_response): + block_size = args[2] if len(args) > 2 else None + return http_response( + request=args[0], + internal_response=args[1], + block_size=block_size, + **kwargs + ) + return http_response(*args, **kwargs) - assert isinstance(response.status_code, int) - with pytest.raises(AttributeError): - response.status_code = 200 +def create_http_response(http_response, *args, **kwargs): + # since the actual http_response object is + # an ABC for our new responses, it's a little more + # complicated creating a pure http response. + # here, we use the HttpResponsImpl, but we still have to pass + # more things to the init, so making a separate operation + if is_rest(http_response): + block_size = args[2] if len(args) > 2 else None + return http_response( + request=args[0], + internal_response=args[1], + block_size=block_size, + status_code=kwargs.pop("status_code", 200), + reason=kwargs.pop("reason", "OK"), + content_type=kwargs.pop("content_type", "application/json"), + headers=kwargs.pop("headers", {}), + stream_download_generator=kwargs.pop("stream_download_generator", None), + **kwargs + ) + return http_response(*args, **kwargs) +def readonly_checks(response, old_response_class): + # though we want these properties to be completely readonly, it doesn't work + # for the backcompat properties + assert isinstance(response.request, RestHttpRequest) + assert isinstance(response.status_code, int) assert response.headers - with pytest.raises(AttributeError): - response.headers = {"hello": "world"} - - assert response.reason == "OK" - with pytest.raises(AttributeError): - response.reason = "Not OK" - assert response.content_type == 'text/html; charset=utf-8' - with pytest.raises(AttributeError): - response.content_type = "bad content type" assert response.is_closed with pytest.raises(AttributeError): @@ -83,3 +149,18 @@ def readonly_checks(response): assert response.content is not None with pytest.raises(AttributeError): response.content = b"bad" + + old_response = old_response_class(response.request, response.internal_response, response.block_size) + for attr in dir(response): + if attr[0] == '_': + # don't care about private variables + continue + if type(getattr(response, attr)) == types.MethodType: + # methods aren't "readonly" + continue + if attr == "encoding": + # encoding is the only settable new attr + continue + if not attr in vars(old_response): + with pytest.raises(AttributeError): + setattr(response, attr, "new_value")