diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py index e3a659f9c867..c86e738c7ef7 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py @@ -31,6 +31,7 @@ validate_and_format_range_headers) from ._shared.response_handlers import return_response_headers, process_storage_error, return_headers_and_deserialized from ._generated import AzureBlobStorage +from ._generated import models as generated_models from ._generated.models import ( # pylint: disable=unused-import DeleteSnapshotsOptionType, BlobHTTPHeaders, @@ -175,6 +176,8 @@ def __init__( self._query_str, credential = self._format_query_string(sas_token, credential, snapshot=self.snapshot) super(BlobClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs) self._client = AzureBlobStorage(self.url, pipeline=self._pipeline) + if not self._msrest_xml: + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py index d277a094921a..cd0f4f3137ac 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py @@ -28,6 +28,7 @@ from ._shared.response_handlers import return_response_headers, process_storage_error, \ parse_to_internal_user_delegation_key from ._generated import AzureBlobStorage +from ._generated import models as generated_models from ._generated.models import StorageServiceProperties, KeyInfo from ._container_client import ContainerClient from 
._blob_client import BlobClient @@ -134,6 +135,8 @@ def __init__( self._query_str, credential = self._format_query_string(sas_token, credential) super(BlobServiceClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs) self._client = AzureBlobStorage(self.url, pipeline=self._pipeline) + if not self._msrest_xml: + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access @@ -676,7 +679,7 @@ def get_container_client(self, container): credential=self.credential, api_version=self.api_version, _configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts, require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - key_resolver_function=self.key_resolver_function) + key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml) def get_blob_client( self, container, # type: Union[ContainerProperties, str] @@ -729,4 +732,4 @@ def get_blob_client( credential=self.credential, api_version=self.api_version, _configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts, require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - key_resolver_function=self.key_resolver_function) + key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py index b63556ba61bc..d8ebad17a640 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py @@ -34,6 +34,7 @@ return_response_headers, return_headers_and_deserialized) from ._generated import 
AzureBlobStorage +from ._generated import models as generated_models from ._generated.models import SignedIdentifier from ._deserialize import deserialize_container_properties from ._serialize import get_modify_conditions, get_container_cpk_scope_info, get_api_version, get_access_conditions @@ -156,6 +157,8 @@ def __init__( self._query_str, credential = self._format_query_string(sas_token, credential) super(ContainerClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs) self._client = AzureBlobStorage(self.url, pipeline=self._pipeline) + if not self._msrest_xml: + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access @@ -769,7 +772,41 @@ def list_blobs(self, name_starts_with=None, include=None, **kwargs): timeout=timeout, **kwargs) return ItemPaged( - command, prefix=name_starts_with, results_per_page=results_per_page, + command, + prefix=name_starts_with, + results_per_page=results_per_page, + select=None, + deserializer=self._client._deserialize, # pylint: disable=protected-access + page_iterator_class=BlobPropertiesPaged) + + @distributed_trace + def list_blob_names(self, **kwargs): + # type: (**Any) -> ItemPaged[str] + """Returns a generator to list the names of blobs under the specified container. + The generator will lazily follow the continuation tokens returned by + the service. + + :keyword str name_starts_with: + Filters the results to return only blobs whose names + begin with the specified prefix. + :keyword int timeout: + The timeout parameter is expressed in seconds. + :returns: An iterable (auto-paging) response of blob names as strings. 
+ :rtype: ~azure.core.paging.ItemPaged[str] + """ + name_starts_with = kwargs.pop('name_starts_with', None) + results_per_page = kwargs.pop('results_per_page', None) + timeout = kwargs.pop('timeout', None) + command = functools.partial( + self._client.container.list_blob_flat_segment, + timeout=timeout, + **kwargs) + return ItemPaged( + command, + prefix=name_starts_with, + results_per_page=results_per_page, + select=["name"], + deserializer=self._client._deserialize, # pylint: disable=protected-access page_iterator_class=BlobPropertiesPaged) @distributed_trace @@ -816,6 +853,8 @@ def walk_blobs( command, prefix=name_starts_with, results_per_page=results_per_page, + select=None, + deserializer=self._client._deserialize, # pylint: disable=protected-access delimiter=delimiter) @distributed_trace @@ -1548,4 +1587,4 @@ def get_blob_client( credential=self.credential, api_version=self.api_version, _configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts, require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - key_resolver_function=self.key_resolver_function) + key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py index bec837429209..f4d4e68cdaa1 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py @@ -1452,12 +1452,12 @@ async def list_blob_flat_segment( response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id')) response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version')) 
response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date')) - deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response) + #deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response) if cls: - return cls(pipeline_response, deserialized, response_headers) + return cls(pipeline_response, None, response_headers) - return deserialized + return None list_blob_flat_segment.metadata = {'url': '/{containerName}'} # type: ignore async def list_blob_hierarchy_segment( @@ -1564,12 +1564,12 @@ async def list_blob_hierarchy_segment( response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id')) response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version')) response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date')) - deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response) + # deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response) if cls: - return cls(pipeline_response, deserialized, response_headers) + return cls(pipeline_response, None, response_headers) - return deserialized + return None list_blob_hierarchy_segment.metadata = {'url': '/{containerName}'} # type: ignore async def get_account_info( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py index f01bbc4393fe..018c93984bf2 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py @@ -1471,12 +1471,12 @@ def list_blob_flat_segment( response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id')) 
response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version')) response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date')) - deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response) + #deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response) if cls: - return cls(pipeline_response, deserialized, response_headers) + return cls(pipeline_response, None, response_headers) - return deserialized + return None # deserialized list_blob_flat_segment.metadata = {'url': '/{containerName}'} # type: ignore def list_blob_hierarchy_segment( @@ -1584,12 +1584,12 @@ def list_blob_hierarchy_segment( response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id')) response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version')) response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date')) - deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response) + # deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response) if cls: - return cls(pipeline_response, deserialized, response_headers) + return cls(pipeline_response, None, response_headers) - return deserialized + return None # deserialized list_blob_hierarchy_segment.metadata = {'url': '/{containerName}'} # type: ignore def get_account_info( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py index 309d37bd9583..6b4e1e6bd6b4 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py @@ -7,14 +7,70 @@ from azure.core.paging import PageIterator, ItemPaged from azure.core.exceptions import HttpResponseError + from ._deserialize import get_blob_properties_from_generated_code, 
parse_tags -from ._generated.models import BlobItemInternal, BlobPrefix as GenBlobPrefix, FilterBlobItem +from ._generated.models import FilterBlobItem from ._models import BlobProperties, FilteredBlob from ._shared.models import DictMixin +from ._shared.xml_deserialization import unpack_xml_content from ._shared.response_handlers import return_context_and_deserialized, process_storage_error -class BlobPropertiesPaged(PageIterator): +def deserialize_list_result(pipeline_response, *_): + payload = unpack_xml_content(pipeline_response.http_response) + location = pipeline_response.http_response.location_mode + return location, payload + + +def load_xml_string(element, name): + node = element.find(name) + if node is None or not node.text: + return None + return node.text + + +def load_xml_int(element, name): + node = element.find(name) + if node is None or not node.text: + return None + return int(node.text) + + +def load_xml_bool(element, name): + node = load_xml_string(element, name) + if node and node.lower() == 'true': + return True + return False + + +def load_single_node(element, name): + return element.find(name) + + +def load_many_nodes(element, name, wrapper=None): + if wrapper: + element = load_single_node(element, wrapper) + return list(element.findall(name)) + + +def blob_properties_from_xml(element, select, deserializer): + if not select: + generated = deserializer.deserialize_data(element, 'BlobItemInternal') + return get_blob_properties_from_generated_code(generated) + blob = BlobProperties() + if 'name' in select: + blob.name = load_xml_string(element, 'Name') + if 'deleted' in select: + blob.deleted = load_xml_bool(element, 'Deleted') + if 'snapshot' in select: + blob.snapshot = load_xml_string(element, 'Snapshot') + if 'version' in select: + blob.version_id = load_xml_string(element, 'VersionId') + blob.is_current_version = load_xml_bool(element, 'IsCurrentVersion') + return blob + + +class BlobPropertiesPaged(PageIterator): # pylint: 
disable=too-many-instance-attributes """An Iterable of Blob properties. :ivar str service_endpoint: The service URL. @@ -49,6 +105,8 @@ def __init__( container=None, prefix=None, results_per_page=None, + select=None, + deserializer=None, continuation_token=None, delimiter=None, location_mode=None): @@ -58,10 +116,12 @@ def __init__( continuation_token=continuation_token or "" ) self._command = command + self._deserializer = deserializer self.service_endpoint = None self.prefix = prefix self.marker = None self.results_per_page = results_per_page + self.select = select self.container = container self.delimiter = delimiter self.current_page = None @@ -73,30 +133,29 @@ def _get_next_cb(self, continuation_token): prefix=self.prefix, marker=continuation_token or None, maxresults=self.results_per_page, - cls=return_context_and_deserialized, + cls=deserialize_list_result, use_location=self.location_mode) except HttpResponseError as error: process_storage_error(error) def _extract_data_cb(self, get_next_return): self.location_mode, self._response = get_next_return - self.service_endpoint = self._response.service_endpoint - self.prefix = self._response.prefix - self.marker = self._response.marker - self.results_per_page = self._response.max_results - self.container = self._response.container_name - self.current_page = [self._build_item(item) for item in self._response.segment.blob_items] + self.service_endpoint = self._response.get('ServiceEndpoint') + self.prefix = load_xml_string(self._response, 'Prefix') + self.marker = load_xml_string(self._response, 'Marker') + self.results_per_page = load_xml_int(self._response, 'MaxResults') + self.container = self._response.get('ContainerName') - return self._response.next_marker or None, self.current_page + blobs = load_many_nodes(self._response, 'Blob', wrapper='Blobs') + self.current_page = [self._build_item(blob) for blob in blobs] + + next_marker = load_xml_string(self._response, 'NextMarker') + return next_marker or None, 
self.current_page def _build_item(self, item): - if isinstance(item, BlobProperties): - return item - if isinstance(item, BlobItemInternal): - blob = get_blob_properties_from_generated_code(item) # pylint: disable=protected-access - blob.container = self.container - return blob - return item + blob = blob_properties_from_xml(item, self.select, self._deserializer) + blob.container = self.container + return blob class BlobPrefixPaged(BlobPropertiesPaged): @@ -106,22 +165,26 @@ def __init__(self, *args, **kwargs): def _extract_data_cb(self, get_next_return): continuation_token, _ = super(BlobPrefixPaged, self)._extract_data_cb(get_next_return) - self.current_page = self._response.segment.blob_prefixes + self._response.segment.blob_items - self.current_page = [self._build_item(item) for item in self.current_page] - self.delimiter = self._response.delimiter + + blob_prefixes = load_many_nodes(self._response, 'BlobPrefix', wrapper='Blobs') + blob_prefixes = [self._build_prefix(blob) for blob in blob_prefixes] + + self.current_page = blob_prefixes + self.current_page + self.delimiter = load_xml_string(self._response, 'Delimiter') return continuation_token, self.current_page - def _build_item(self, item): - item = super(BlobPrefixPaged, self)._build_item(item) - if isinstance(item, GenBlobPrefix): - return BlobPrefix( - self._command, - container=self.container, - prefix=item.name, - results_per_page=self.results_per_page, - location_mode=self.location_mode) - return item + def _build_prefix(self, item): + return BlobPrefix( + self._command, + container=self.container, + prefix=load_xml_string(item, 'Name'), + results_per_page=self.results_per_page, + location_mode=self.location_mode, + select=self.select, + deserializer=self._deserializer, + delimiter=self.delimiter + ) class BlobPrefix(ItemPaged, DictMixin): diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py 
index a2efa2170228..955a8073f5db 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -35,6 +35,7 @@ AzureSasCredentialPolicy ) +from .xml_deserialization import Deserializer from .constants import STORAGE_OAUTH_SCOPE, SERVICE_HOST_BASE, CONNECTION_TIMEOUT, READ_TIMEOUT from .models import LocationMode from .authentication import SharedKeyCredentialPolicy @@ -74,6 +75,7 @@ def __init__( # type: (...) -> None self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY) self._hosts = kwargs.get("_hosts") + self._msrest_xml = kwargs.get('msrest_xml', False) self.scheme = parsed_url.scheme if service not in ["blob", "queue", "file-share", "dfs"]: @@ -199,6 +201,20 @@ def api_version(self): """ return self._client._config.version # pylint: disable=protected-access + def _custom_xml_deserializer(self, generated_models): + """Reset the deserializer on the generated client to be Storage implementation""" + # pylint: disable=protected-access + client_models = {k: v for k, v in generated_models.__dict__.items() if isinstance(v, type)} + custom_deserialize = Deserializer(client_models) + self._client._deserialize = custom_deserialize + self._client.service._deserialize = custom_deserialize + self._client.container._deserialize = custom_deserialize + self._client.directory._deserialize = custom_deserialize + self._client.blob._deserialize = custom_deserialize + self._client.page_blob._deserialize = custom_deserialize + self._client.append_blob._deserialize = custom_deserialize + self._client.block_blob._deserialize = custom_deserialize + def _format_query_string(self, sas_token, credential, snapshot=None, share_snapshot=None): query_str = "?" 
if snapshot: @@ -237,14 +253,13 @@ def _create_pipeline(self, credential, **kwargs): config.transport = RequestsTransport(**kwargs) policies = [ QueueMessagePolicy(), + config.headers_policy, config.proxy_policy, config.user_agent_policy, StorageContentValidation(), - ContentDecodePolicy(response_encoding="utf-8"), RedirectPolicy(**kwargs), StorageHosts(hosts=self._hosts, **kwargs), config.retry_policy, - config.headers_policy, StorageRequestHook(**kwargs), self._credential_policy, config.logging_policy, @@ -252,6 +267,8 @@ def _create_pipeline(self, credential, **kwargs): DistributedTracingPolicy(**kwargs), HttpLoggingPolicy(**kwargs) ] + if self._msrest_xml: + policies.insert(5, ContentDecodePolicy(response_encoding="utf-8")) if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") return config, Pipeline(config.transport, policies=policies) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py index 3e619c90fd71..5deca436299b 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py @@ -15,12 +15,12 @@ from azure.core.async_paging import AsyncList from azure.core.exceptions import HttpResponseError from azure.core.pipeline.policies import ( - ContentDecodePolicy, AsyncBearerTokenCredentialPolicy, AsyncRedirectPolicy, DistributedTracingPolicy, HttpLoggingPolicy, AzureSasCredentialPolicy, + ContentDecodePolicy ) from azure.core.pipeline.transport import AsyncHttpTransport @@ -97,7 +97,6 @@ def _create_pipeline(self, credential, **kwargs): StorageContentValidation(), StorageRequestHook(**kwargs), self._credential_policy, - ContentDecodePolicy(response_encoding="utf-8"), AsyncRedirectPolicy(**kwargs), StorageHosts(hosts=self._hosts, **kwargs), # type: ignore 
config.retry_policy, @@ -106,6 +105,8 @@ def _create_pipeline(self, credential, **kwargs): DistributedTracingPolicy(**kwargs), HttpLoggingPolicy(**kwargs), ] + if self._msrest_xml: + policies.insert(5, ContentDecodePolicy(response_encoding="utf-8")) if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") return config, AsyncPipeline(config.transport, policies=policies) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py new file mode 100644 index 000000000000..20f53f028354 --- /dev/null +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py @@ -0,0 +1,587 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- + +from base64 import b64decode +import datetime +import decimal +import email +from enum import Enum +import logging +import re +import xml.etree.ElementTree as ET + +import isodate +from msrest.exceptions import DeserializationError, raise_with_traceback +from msrest.serialization import ( + TZ_UTC, + _FixedOffset +) + +from azure.core.exceptions import DecodeError + + +try: + basestring # pylint: disable=pointless-statement + unicode_str = unicode # type: ignore +except NameError: + basestring = str # type: ignore + unicode_str = str # type: ignore + +_LOGGER = logging.getLogger(__name__) +_valid_date = re.compile( + r'\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}' + r'\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?') + +try: + _long_type = long # type: ignore +except NameError: + _long_type = int + + +def unpack_xml_content(response_data, **kwargs): + """Extract the correct structure for deserialization. + + If raw_data is a PipelineResponse, try to extract the result of RawDeserializer. + if we can't, raise. Your Pipeline should have a RawDeserializer. + + If not a pipeline response and raw_data is bytes or string, use content-type + to decode it. If no content-type, try JSON. + + If raw_data is something else, bypass all logic and return it directly. + + :param raw_data: Data to be processed. + :param content_type: How to parse if raw_data is a string/bytes. 
+ :raises UnicodeDecodeError: If bytes is not UTF8 + """ + try: + return ET.fromstring(response_data.body()) # nosec + except ET.ParseError: + _LOGGER.critical("Response body invalid XML") + raise_with_traceback(DecodeError, message="XML is invalid", response=response_data, **kwargs) + + +def deserialize_bytearray(attr, *_): + """Deserialize string into bytearray. + + :param str attr: response string to be deserialized. + :rtype: bytearray + :raises: TypeError if string format invalid. + """ + return bytearray(b64decode(attr)) + + +def deserialize_base64(attr, *_): + """Deserialize base64 encoded string into string. + + :param str attr: response string to be deserialized. + :rtype: bytearray + :raises: TypeError if string format invalid. + """ + padding = '=' * (3 - (len(attr) + 3) % 4) + attr = attr + padding + encoded = attr.replace('-', '+').replace('_', '/') + return b64decode(encoded) + + +def deserialize_decimal(attr, *_): + """Deserialize string into Decimal object. + + :param str attr: response string to be deserialized. + :rtype: Decimal + :raises: DeserializationError if string format invalid. + """ + try: + return decimal.Decimal(attr) + except decimal.DecimalException as err: + msg = "Invalid decimal {}".format(attr) + raise_with_traceback(DeserializationError, msg, err) + + +def deserialize_bool(attr, *_): + """Deserialize string into bool. + + :param str attr: response string to be deserialized. + :rtype: bool + :raises: TypeError if string format is not valid. + """ + if attr in [True, False, 1, 0]: + return bool(attr) + if isinstance(attr, basestring): + if attr.lower() in ['true', '1']: + return True + if attr.lower() in ['false', '0']: + return False + raise TypeError("Invalid boolean value: {}".format(attr)) + + +def deserialize_int(attr, *_): + """Deserialize string into int. + + :param str attr: response string to be deserialized. + :rtype: int + :raises: ValueError or TypeError if string format invalid. 
+ """ + return int(attr) + + +def deserialize_float(attr, *_): + """Deserialize string into float. + + :param str attr: response string to be deserialized. + :rtype: float + :raises: ValueError if string format invalid. + """ + return float(attr) + + +def deserialize_long(attr, *_): + """Deserialize string into long (Py2) or int (Py3). + + :param str attr: response string to be deserialized. + :rtype: long or int + :raises: ValueError if string format invalid. + """ + return _long_type(attr) + + +def deserialize_duration(attr, *_): + """Deserialize ISO-8601 formatted string into TimeDelta object. + + :param str attr: response string to be deserialized. + :rtype: TimeDelta + :raises: DeserializationError if string format invalid. + """ + try: + duration = isodate.parse_duration(attr) + except(ValueError, OverflowError, AttributeError) as err: + msg = "Cannot deserialize duration object." + raise_with_traceback(DeserializationError, msg, err) + else: + return duration + + +def deserialize_date(attr, *_): + """Deserialize ISO-8601 formatted string into Date object. + + :param str attr: response string to be deserialized. + :rtype: Date + :raises: DeserializationError if string format invalid. + """ + if re.search(r"[^\W\d_]", attr, re.I + re.U): + raise DeserializationError("Date must have only digits and -. Received: %s" % attr) + # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception. + return isodate.parse_date(attr, defaultmonth=None, defaultday=None) + + +def deserialize_time(attr, *_): + """Deserialize ISO-8601 formatted string into time object. + + :param str attr: response string to be deserialized. + :rtype: datetime.time + :raises: DeserializationError if string format invalid. + """ + if re.search(r"[^\W\d_]", attr, re.I + re.U): + raise DeserializationError("Date must have only digits and -. 
Received: %s" % attr) + return isodate.parse_time(attr) + + +def deserialize_rfc(attr, *_): + """Deserialize RFC-1123 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: Datetime + :raises: DeserializationError if string format invalid. + """ + try: + parsed_date = email.utils.parsedate_tz(attr) + date_obj = datetime.datetime( + *parsed_date[:6], + tzinfo=_FixedOffset(datetime.timedelta(minutes=(parsed_date[9] or 0)/60)) + ) + if not date_obj.tzinfo: + date_obj = date_obj.astimezone(tz=TZ_UTC) + except ValueError as err: + msg = "Cannot deserialize to rfc datetime object." + raise_with_traceback(DeserializationError, msg, err) + else: + return date_obj + + +def deserialize_iso(attr, *_): + """Deserialize ISO-8601 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: Datetime + :raises: DeserializationError if string format invalid. + """ + try: + attr = attr.upper() + match = _valid_date.match(attr) + if not match: + raise ValueError("Invalid datetime string: " + attr) + + check_decimal = attr.split('.') + if len(check_decimal) > 1: + decimal_str = "" + for digit in check_decimal[1]: + if digit.isdigit(): + decimal_str += digit + else: + break + if len(decimal_str) > 6: + attr = attr.replace(decimal_str, decimal_str[0:6]) + + date_obj = isodate.parse_datetime(attr) + test_utc = date_obj.utctimetuple() + if test_utc.tm_year > 9999 or test_utc.tm_year < 1: + raise OverflowError("Hit max or min date") + except(ValueError, OverflowError, AttributeError) as err: + msg = "Cannot deserialize datetime object." + raise_with_traceback(DeserializationError, msg, err) + else: + return date_obj + + +def deserialize_object(attr, *_): + """Deserialize a generic object. + This will be handled as a dictionary. + + :param dict attr: Dictionary to be deserialized. + :rtype: dict + :raises: TypeError if non-builtin datatype encountered. 
+ """ + # Do no recurse on XML, just return the tree as-is + # TODO: This probably needs work + return attr + + +def deserialize_unix(attr, *_): + """Serialize Datetime object into IntTime format. + This is represented as seconds. + + :param int attr: Object to be serialized. + :rtype: Datetime + :raises: DeserializationError if format invalid + """ + try: + date_obj = datetime.datetime.fromtimestamp(int(attr), TZ_UTC) + except ValueError as err: + msg = "Cannot deserialize to unix datetime object." + raise_with_traceback(DeserializationError, msg, err) + else: + return date_obj + + +def deserialize_unicode(data, *_): + """Preserve unicode objects in Python 2, otherwise return data + as a string. + + :param str data: response string to be deserialized. + :rtype: str or unicode + """ + if data is None: + return "" + # We might be here because we have an enum modeled as string, + # and we try to deserialize a partial dict with enum inside + if isinstance(data, Enum): + return data + + # Consider this is real string + try: + if isinstance(data, unicode): + return data + except NameError: + return str(data) + else: + return str(data) + + +def deserialize_enum(data, enum_obj): + """Deserialize string into enum object. + + If the string is not a valid enum value it will be returned as-is + and a warning will be logged. + + :param str data: Response string to be deserialized. If this value is + None or invalid it will be returned as-is. + :param Enum enum_obj: Enum object to deserialize to. + :rtype: Enum + """ + if isinstance(data, enum_obj) or data is None: + return data + if isinstance(data, Enum): + data = data.value + if isinstance(data, int): + # Workaround. We might consider remove it in the future. 
+ # https://github.com/Azure/azure-rest-api-specs/issues/141 + try: + return list(enum_obj.__members__.values())[data] + except IndexError: + error = "{!r} is not a valid index for enum {!r}" + raise DeserializationError(error.format(data, enum_obj)) + try: + return enum_obj(str(data)) + except ValueError: + for enum_value in enum_obj: + if enum_value.value.lower() == str(data).lower(): + return enum_value + # We don't fail anymore for unknown value, we deserialize as a string + _LOGGER.warning("Deserializer is not able to find %s as valid enum in %s", data, enum_obj) + return deserialize_unicode(data) + + +def instantiate_model(response, attrs, additional_properties=None): + """Instantiate a response model passing in deserialized args. + + :param response: The response model class. + :param d_attrs: The deserialized response attributes. + """ + try: + readonly = [k for k, v in response._validation.items() if v.get('readonly')] # pylint:disable=protected-access + const = [k for k, v in response._validation.items() if v.get('constant')] # pylint:disable=protected-access + kwargs = {k: v for k, v in attrs.items() if k not in readonly + const} + response_obj = response(**kwargs) + for attr in readonly: + setattr(response_obj, attr, attrs.get(attr)) + if additional_properties: + response_obj.additional_properties = additional_properties + return response_obj + except Exception as err: + msg = "Unable to deserialize {} into model {}. 
".format( + kwargs, response) + raise DeserializationError(msg + str(err)) + + +def multi_xml_key_extractor(attr_desc, data, subtype): + xml_desc = attr_desc.get('xml', {}) + xml_name = xml_desc.get('name', attr_desc['key']) + is_wrapped = xml_desc.get("wrapped", False) + subtype_xml_map = getattr(subtype, "_xml_map", {}) + if is_wrapped: + items_name = xml_name + elif subtype: + items_name = subtype_xml_map.get('name', xml_name) + else: + items_name = xml_desc.get("itemsName", xml_name) + children = data.findall(items_name) + if is_wrapped: + if len(children) == 0: + return None + return list(children[0]) + return children + +def xml_key_extractor(attr_desc, data, subtype): + xml_desc = attr_desc.get('xml', {}) + xml_name = xml_desc.get('name', attr_desc['key']) + + # If it's an attribute, that's simple + if xml_desc.get("attr", False): + return data.get(xml_name) + + # If it's x-ms-text, that's simple too + if xml_desc.get("text", False): + return data.text + + subtype_xml_map = getattr(subtype, "_xml_map", {}) + xml_name = subtype_xml_map.get('name', xml_name) + return data.find(xml_name) + + +class Deserializer(object): + """Response object model deserializer. + + :param dict classes: Class type dictionary for deserializing complex types. 
+ """ + def __init__(self, classes=None): + self.deserialize_type = { + 'str': deserialize_unicode, + 'int': deserialize_int, + 'bool': deserialize_bool, + 'float': deserialize_float, + 'iso-8601': deserialize_iso, + 'rfc-1123': deserialize_rfc, + 'unix-time': deserialize_unix, + 'duration': deserialize_duration, + 'date': deserialize_date, + 'time': deserialize_time, + 'decimal': deserialize_decimal, + 'long': deserialize_long, + 'bytearray': deserialize_bytearray, + 'base64': deserialize_base64, + 'object': deserialize_object, + '[]': self.deserialize_iter, + '{}': self.deserialize_dict + } + + self.dependencies = dict(classes) if classes else {} + + def __call__(self, target_obj, response_data, **kwargs): + """Call the deserializer to process a REST response. + + :param str target_obj: Target data type to deserialize to. + :param requests.Response response_data: REST response object. + :param str content_type: Swagger "produces" if available. + :raises: DeserializationError if deserialization fails. + :return: Deserialized object. + """ + try: + # First, unpack the response if we have one. + response_data = unpack_xml_content(response_data.http_response, **kwargs) + except AttributeError: + pass + if response_data is None: + # No data. Moving on. + return None + #return self._deserialize(target_obj, response_data) + return self.deserialize_data(response_data, target_obj) + + def failsafe_deserialize(self, target_obj, data, content_type=None): + """Ignores any errors encountered in deserialization, + and falls back to not deserializing the object. Recommended + for use in error deserialization, as we want to return the + HttpResponseError to users, and not have them deal with + a deserialization error. + + :param str target_obj: The target object type to deserialize to. + :param str/dict data: The response data to deseralize. + :param str content_type: Swagger "produces" if available. 
+ """ + try: + return self(target_obj, data, content_type=content_type) + except: # pylint: disable=bare-except + _LOGGER.warning( + "Ran into a deserialization error. Ignoring since this is failsafe deserialization", + exc_info=True + ) + return None + + def _deserialize(self, target_obj, data): + """Call the deserializer on a model. + + Data needs to be already deserialized as JSON or XML ElementTree + + :param str target_obj: Target data type to deserialize to. + :param object data: Object to deserialize. + :raises: DeserializationError if deserialization fails. + :return: Deserialized object. + """ + try: + model_type = self.dependencies[target_obj] + if issubclass(model_type, Enum): + return deserialize_enum(data.text, model_type) + except KeyError: + return self.deserialize_data(data, target_obj) + + if data is None: + return data + try: + attributes = model_type._attribute_map # pylint:disable=protected-access + d_attrs = {} + include_extra_props = False + for attr, attr_desc in attributes.items(): + # Check empty string. If it's not empty, someone has a real "additionalProperties"... 
+ if attr == "additional_properties" and attr_desc["key"] == '': + include_extra_props = True + continue + attr_type = attr_desc["type"] + try: + # TODO: Validate this subtype logic + subtype = self.dependencies[attr_type.strip('[]{}')] + except KeyError: + subtype = None + if attr_type[0] == '[': + raw_value = multi_xml_key_extractor(attr_desc, data, subtype) + else: + raw_value = xml_key_extractor(attr_desc, data, subtype) + value = self.deserialize_data(raw_value, attr_type) + d_attrs[attr] = value + except (AttributeError, TypeError, KeyError) as err: + msg = "Unable to deserialize to object: " + str(target_obj) + raise_with_traceback(DeserializationError, msg, err) + else: + if include_extra_props: + extra = {el.tag: el.text for el in data if el.tag not in d_attrs} + return instantiate_model(model_type, d_attrs, extra) + return instantiate_model(model_type, d_attrs) + + def deserialize_data(self, data, data_type): + """Process data for deserialization according to data type. + + :param str data: The response string to be deserialized. + :param str data_type: The type to deserialize to. + :raises: DeserializationError if deserialization fails. + :return: Deserialized object. + """ + if not data_type or data is None: + return None + try: + xml_data = data.text + except AttributeError: + xml_data = data + + try: + basic_deserialize = self.deserialize_type[data_type] + if not xml_data and data_type != 'str': + return None + return basic_deserialize(xml_data, data_type) + except KeyError: + pass + except (ValueError, TypeError, AttributeError) as err: + msg = "Unable to deserialize response data." + msg += " Data: {}, {}".format(data, data_type) + raise_with_traceback(DeserializationError, msg, err) + + try: + iter_type = data_type[0] + data_type[-1] + if iter_type in self.deserialize_type: + return self.deserialize_type[iter_type](data, data_type[1:-1]) + except (ValueError, TypeError, AttributeError) as err: + msg = "Unable to deserialize response data." 
+ msg += " Data: {}, {}".format(data, data_type) + raise_with_traceback(DeserializationError, msg, err) + else: + return self._deserialize(data_type, data) + + def deserialize_iter(self, attr, iter_type): + """Deserialize an iterable. + + :param list attr: Iterable to be deserialized. + :param str iter_type: The type of object in the iterable. + :rtype: list + """ + return [self.deserialize_data(a, iter_type) for a in list(attr)] + + def deserialize_dict(self, attr, dict_type): + """Deserialize a dictionary. + + :param dict/list attr: Dictionary to be deserialized. Also accepts + a list of key, value pairs. + :param str dict_type: The object type of the items in the dictionary. + :rtype: dict + """ + # Transform value into {"Key": "value"} + attr = {el.tag: el.text for el in attr} + return {k: self.deserialize_data(v, dict_type) for k, v in attr.items()} diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py index d13de28f8711..ec5266183770 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py @@ -21,6 +21,7 @@ from .._deserialize import get_page_ranges_result, parse_tags, deserialize_pipeline_response_into_cls from .._serialize import get_modify_conditions, get_api_version, get_access_conditions from .._generated.aio import AzureBlobStorage +from .._generated import models as generated_models from .._generated.models import CpkInfo from .._deserialize import deserialize_blob_properties from .._blob_client import BlobClient as BlobClientBase @@ -120,6 +121,8 @@ def __init__( credential=credential, **kwargs) self._client = AzureBlobStorage(url=self.url, pipeline=self._pipeline) + if not self._msrest_xml: + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access 
self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py index d50661d8e2d7..3a2fd1eb4b14 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py @@ -24,6 +24,7 @@ from .._shared.parser import _to_utc_datetime from .._shared.response_handlers import parse_to_internal_user_delegation_key from .._generated.aio import AzureBlobStorage +from .._generated import models as generated_models from .._generated.models import StorageServiceProperties, KeyInfo from .._blob_service_client import BlobServiceClient as BlobServiceClientBase from ._container_client_async import ContainerClient @@ -118,6 +119,8 @@ def __init__( credential=credential, **kwargs) self._client = AzureBlobStorage(url=self.url, pipeline=self._pipeline) + if not self._msrest_xml: + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access @@ -619,7 +622,7 @@ def get_container_client(self, container): credential=self.credential, api_version=self.api_version, _configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts, require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - key_resolver_function=self.key_resolver_function) + key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml) def get_blob_client( self, container, # type: Union[ContainerProperties, str] @@ -674,4 +677,4 @@ def get_blob_client( credential=self.credential, api_version=self.api_version, 
_configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts, require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - key_resolver_function=self.key_resolver_function) + key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py index cd0164392ab6..9a919ec6c36b 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py @@ -26,6 +26,7 @@ return_response_headers, return_headers_and_deserialized) from .._generated.aio import AzureBlobStorage +from .._generated import models as generated_models from .._generated.models import SignedIdentifier from .._deserialize import deserialize_container_properties from .._serialize import get_modify_conditions, get_container_cpk_scope_info, get_api_version, get_access_conditions @@ -117,6 +118,8 @@ def __init__( credential=credential, **kwargs) self._client = AzureBlobStorage(url=self.url, pipeline=self._pipeline) + if not self._msrest_xml: + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access @@ -633,6 +636,39 @@ def list_blobs(self, name_starts_with=None, include=None, **kwargs): command, prefix=name_starts_with, results_per_page=results_per_page, + select=None, + deserializer=self._client._deserialize, # pylint: disable=protected-access + page_iterator_class=BlobPropertiesPaged + ) + + @distributed_trace + def list_blob_names(self, **kwargs): + # type: (**Any) -> AsyncItemPaged[str] + """Returns a generator to list the names of blobs under the 
specified container. + The generator will lazily follow the continuation tokens returned by + the service. + + :keyword str name_starts_with: + Filters the results to return only blobs whose names + begin with the specified prefix. + :keyword int timeout: + The timeout parameter is expressed in seconds. + :returns: An iterable (auto-paging) response of blob names as strings. + :rtype: ~azure.core.async_paging.AsyncItemPaged[str] + """ + name_starts_with = kwargs.pop('name_starts_with', None) + results_per_page = kwargs.pop('results_per_page', None) + timeout = kwargs.pop('timeout', None) + command = functools.partial( + self._client.container.list_blob_flat_segment, + timeout=timeout, + **kwargs) + return AsyncItemPaged( + command, + prefix=name_starts_with, + results_per_page=results_per_page, + select=["name"], + deserializer=self._client._deserialize, # pylint: disable=protected-access page_iterator_class=BlobPropertiesPaged ) @@ -680,6 +716,8 @@ def walk_blobs( command, prefix=name_starts_with, results_per_page=results_per_page, + select=None, + deserializer=self._client._deserialize, # pylint: disable=protected-access delimiter=delimiter) @distributed_trace_async @@ -1206,4 +1244,4 @@ def get_blob_client( credential=self.credential, api_version=self.api_version, _configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts, require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - key_resolver_function=self.key_resolver_function) + key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py index 058572fd270d..9a11087d7020 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py @@ -7,14 
+7,18 @@ from azure.core.async_paging import AsyncPageIterator, AsyncItemPaged from azure.core.exceptions import HttpResponseError -from .._deserialize import get_blob_properties_from_generated_code -from .._models import BlobProperties -from .._generated.models import BlobItemInternal, BlobPrefix as GenBlobPrefix from .._shared.models import DictMixin -from .._shared.response_handlers import return_context_and_deserialized, process_storage_error +from .._shared.response_handlers import process_storage_error +from .._list_blobs_helper import ( + deserialize_list_result, + load_many_nodes, + load_xml_string, + load_xml_int, + blob_properties_from_xml +) -class BlobPropertiesPaged(AsyncPageIterator): +class BlobPropertiesPaged(AsyncPageIterator): # pylint: disable=too-many-instance-attributes """An Iterable of Blob properties. :ivar str service_endpoint: The service URL. @@ -48,6 +52,8 @@ def __init__( container=None, prefix=None, results_per_page=None, + select=None, + deserializer=None, continuation_token=None, delimiter=None, location_mode=None): @@ -57,10 +63,12 @@ def __init__( continuation_token=continuation_token or "" ) self._command = command + self._deserializer = deserializer self.service_endpoint = None self.prefix = prefix self.marker = None self.results_per_page = results_per_page + self.select = select self.container = container self.delimiter = delimiter self.current_page = None @@ -72,30 +80,29 @@ async def _get_next_cb(self, continuation_token): prefix=self.prefix, marker=continuation_token or None, maxresults=self.results_per_page, - cls=return_context_and_deserialized, + cls=deserialize_list_result, use_location=self.location_mode) except HttpResponseError as error: process_storage_error(error) async def _extract_data_cb(self, get_next_return): self.location_mode, self._response = get_next_return - self.service_endpoint = self._response.service_endpoint - self.prefix = self._response.prefix - self.marker = self._response.marker - 
self.results_per_page = self._response.max_results - self.container = self._response.container_name - self.current_page = [self._build_item(item) for item in self._response.segment.blob_items] + self.service_endpoint = self._response.get('ServiceEndpoint') + self.prefix = load_xml_string(self._response, 'Prefix') + self.marker = load_xml_string(self._response, 'Marker') + self.results_per_page = load_xml_int(self._response, 'MaxResults') + self.container = self._response.get('ContainerName') - return self._response.next_marker or None, self.current_page + blobs = load_many_nodes(self._response, 'Blob', wrapper='Blobs') + self.current_page = [self._build_item(blob) for blob in blobs] + + next_marker = load_xml_string(self._response, 'NextMarker') + return next_marker or None, self.current_page def _build_item(self, item): - if isinstance(item, BlobProperties): - return item - if isinstance(item, BlobItemInternal): - blob = get_blob_properties_from_generated_code(item) # pylint: disable=protected-access - blob.container = self.container - return blob - return item + blob = blob_properties_from_xml(item, self.select, self._deserializer) + blob.container = self.container + return blob class BlobPrefix(AsyncItemPaged, DictMixin): @@ -144,20 +151,21 @@ def __init__(self, *args, **kwargs): self.name = self.prefix async def _extract_data_cb(self, get_next_return): - continuation_token, _ = await super(BlobPrefixPaged, self)._extract_data_cb(get_next_return) - self.current_page = self._response.segment.blob_prefixes + self._response.segment.blob_items - self.current_page = [self._build_item(item) for item in self.current_page] - self.delimiter = self._response.delimiter - + continuation_token, current_page = await super(BlobPrefixPaged, self)._extract_data_cb(get_next_return) + blob_prefixes = load_many_nodes(self._response, 'BlobPrefix', wrapper='Blobs') + blob_prefixes = [self._build_prefix(blob) for blob in blob_prefixes] + self.current_page = blob_prefixes + 
current_page + self.delimiter = load_xml_string(self._response, 'Delimiter') return continuation_token, self.current_page - def _build_item(self, item): - item = super(BlobPrefixPaged, self)._build_item(item) - if isinstance(item, GenBlobPrefix): - return BlobPrefix( - self._command, - container=self.container, - prefix=item.name, - results_per_page=self.results_per_page, - location_mode=self.location_mode) - return item + def _build_prefix(self, item): + return BlobPrefix( + self._command, + container=self.container, + prefix=load_xml_string(item, 'Name'), + results_per_page=self.results_per_page, + location_mode=self.location_mode, + select=self.select, + deserializer=self._deserializer, + delimiter=self.delimiter + ) diff --git a/sdk/storage/azure-storage-blob/tests/perfstress_tests/T1_legacy_tests/list_blobs.py b/sdk/storage/azure-storage-blob/tests/perfstress_tests/T1_legacy_tests/list_blobs.py index b3a55bcf23b9..aedc564e0109 100644 --- a/sdk/storage/azure-storage-blob/tests/perfstress_tests/T1_legacy_tests/list_blobs.py +++ b/sdk/storage/azure-storage-blob/tests/perfstress_tests/T1_legacy_tests/list_blobs.py @@ -17,8 +17,12 @@ async def global_setup(self): blob=b"") def run_sync(self): - for _ in self.service_client.list_blobs(container_name=self.container_name): - pass + if self.args.name_only: + for _ in self.service_client.list_blob_names(container_name=self.container_name): + pass + else: + for _ in self.service_client.list_blobs(container_name=self.container_name): + pass async def run_async(self): raise NotImplementedError("Async not supported for legacy T1 tests.") @@ -27,3 +31,4 @@ async def run_async(self): def add_arguments(parser): super(LegacyListBlobsTest, LegacyListBlobsTest).add_arguments(parser) parser.add_argument('-c', '--count', nargs='?', type=int, help='Number of blobs to list. Defaults to 100', default=100) + parser.add_argument('--name-only', action='store_true', help='Return only blob name. 
Defaults to False', default=False) diff --git a/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py b/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py index ca46e67ffccb..8d9cdaf49829 100644 --- a/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py +++ b/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py @@ -25,6 +25,7 @@ def __init__(self, arguments): self._client_kwargs['max_single_put_size'] = self.args.max_put_size self._client_kwargs['max_block_size'] = self.args.max_block_size self._client_kwargs['min_large_block_upload_threshold'] = self.args.buffer_threshold + self._client_kwargs['msrest_xml'] = not self.args.no_msrest # self._client_kwargs['api_version'] = '2019-02-02' # Used only for comparison with T1 legacy tests if not _ServiceTest.service_client or self.args.no_client_share: @@ -46,6 +47,7 @@ def add_arguments(parser): parser.add_argument('--max-concurrency', nargs='?', type=int, help='Maximum number of concurrent threads used for data transfer. Defaults to 1', default=1) parser.add_argument('-s', '--size', nargs='?', type=int, help='Size of data to transfer. Default is 10240.', default=10240) parser.add_argument('--no-client-share', action='store_true', help='Create one ServiceClient per test instance. Default is to share a single ServiceClient.', default=False) + parser.add_argument('--no-msrest', action='store_true', help='Do not use the msrest XML deserialization pipeline.
Defaults to False', default=False) class _ContainerTest(_ServiceTest): diff --git a/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py b/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py index f5f35a86fff1..65894b044e3f 100644 --- a/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py +++ b/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py @@ -27,14 +27,23 @@ async def global_setup(self): break def run_sync(self): - for _ in self.container_client.list_blobs(): - pass + if self.args.name_only: + for _ in self.container_client.list_blob_names(): + pass + else: + for _ in self.container_client.list_blobs(): + pass async def run_async(self): - async for _ in self.async_container_client.list_blobs(): - pass + if self.args.name_only: + async for _ in self.async_container_client.list_blob_names(): + pass + else: + async for _ in self.async_container_client.list_blobs(): + pass @staticmethod def add_arguments(parser): super(ListBlobsTest, ListBlobsTest).add_arguments(parser) parser.add_argument('-c', '--count', nargs='?', type=int, help='Number of blobs to list. Defaults to 100', default=100) + parser.add_argument('--name-only', action='store_true', help='Return only blob name. 
Defaults to False', default=False) diff --git a/sdk/storage/azure-storage-blob/tests/test_blob_service_stats.py b/sdk/storage/azure-storage-blob/tests/test_blob_service_stats.py index 1de16c8a6538..2a36c6b8f9c2 100644 --- a/sdk/storage/azure-storage-blob/tests/test_blob_service_stats.py +++ b/sdk/storage/azure-storage-blob/tests/test_blob_service_stats.py @@ -12,13 +12,13 @@ from _shared.testcase import GlobalStorageAccountPreparer, GlobalResourceGroupPreparer -SERVICE_UNAVAILABLE_RESP_BODY = 'unavailable ' +SERVICE_UNAVAILABLE_RESP_BODY = b'unavailable ' -SERVICE_LIVE_RESP_BODY = 'liveWed, 19 Jan 2021 22:28:43 GMT ' +SERVICE_LIVE_RESP_BODY = b'liveWed, 19 Jan 2021 22:28:43 GMT ' # --Test Class ----------------------------------------------------------------- class ServiceStatsTest(StorageTestCase): @@ -39,11 +39,11 @@ def _assert_stats_unavailable(self, stats): @staticmethod def override_response_body_with_live_status(response): - response.http_response.text = lambda encoding=None: SERVICE_LIVE_RESP_BODY + response.http_response.body = lambda: SERVICE_LIVE_RESP_BODY @staticmethod def override_response_body_with_unavailable_status(response): - response.http_response.text = lambda encoding=None: SERVICE_UNAVAILABLE_RESP_BODY + response.http_response.body = lambda: SERVICE_UNAVAILABLE_RESP_BODY # --Test cases per service --------------------------------------- @GlobalResourceGroupPreparer() diff --git a/sdk/storage/azure-storage-blob/tests/test_blob_service_stats_async.py b/sdk/storage/azure-storage-blob/tests/test_blob_service_stats_async.py index 380fa67b024d..4b545a2eafab 100644 --- a/sdk/storage/azure-storage-blob/tests/test_blob_service_stats_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_blob_service_stats_async.py @@ -15,14 +15,14 @@ from devtools_testutils.storage.aio import AsyncStorageTestCase -SERVICE_UNAVAILABLE_RESP_BODY = 'unavailable ' +SERVICE_UNAVAILABLE_RESP_BODY = b'unavailable ' -SERVICE_LIVE_RESP_BODY = 'liveWed, 19 Jan 2021 22:28:43 
GMT ' +SERVICE_LIVE_RESP_BODY = b'liveWed, 19 Jan 2021 22:28:43 GMT ' class AiohttpTestTransport(AioHttpTransport): @@ -55,11 +55,11 @@ def _assert_stats_unavailable(self, stats): @staticmethod def override_response_body_with_live_status(response): - response.http_response.text = lambda encoding=None: SERVICE_LIVE_RESP_BODY + response.http_response.body = lambda: SERVICE_LIVE_RESP_BODY @staticmethod def override_response_body_with_unavailable_status(response): - response.http_response.text = lambda encoding=None: SERVICE_UNAVAILABLE_RESP_BODY + response.http_response.body = lambda: SERVICE_UNAVAILABLE_RESP_BODY # --Test cases per service --------------------------------------- @GlobalResourceGroupPreparer()