diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py
index e3a659f9c867..c86e738c7ef7 100644
--- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py
+++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py
@@ -31,6 +31,7 @@
validate_and_format_range_headers)
from ._shared.response_handlers import return_response_headers, process_storage_error, return_headers_and_deserialized
from ._generated import AzureBlobStorage
+from ._generated import models as generated_models
from ._generated.models import ( # pylint: disable=unused-import
DeleteSnapshotsOptionType,
BlobHTTPHeaders,
@@ -175,6 +176,8 @@ def __init__(
self._query_str, credential = self._format_query_string(sas_token, credential, snapshot=self.snapshot)
super(BlobClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs)
self._client = AzureBlobStorage(self.url, pipeline=self._pipeline)
+ if not self._msrest_xml:
+ self._custom_xml_deserializer(generated_models)
default_api_version = self._client._config.version # pylint: disable=protected-access
self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py
index d277a094921a..cd0f4f3137ac 100644
--- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py
+++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py
@@ -28,6 +28,7 @@
from ._shared.response_handlers import return_response_headers, process_storage_error, \
parse_to_internal_user_delegation_key
from ._generated import AzureBlobStorage
+from ._generated import models as generated_models
from ._generated.models import StorageServiceProperties, KeyInfo
from ._container_client import ContainerClient
from ._blob_client import BlobClient
@@ -134,6 +135,8 @@ def __init__(
self._query_str, credential = self._format_query_string(sas_token, credential)
super(BlobServiceClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs)
self._client = AzureBlobStorage(self.url, pipeline=self._pipeline)
+ if not self._msrest_xml:
+ self._custom_xml_deserializer(generated_models)
default_api_version = self._client._config.version # pylint: disable=protected-access
self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access
@@ -676,7 +679,7 @@ def get_container_client(self, container):
credential=self.credential, api_version=self.api_version, _configuration=self._config,
_pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts,
require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key,
- key_resolver_function=self.key_resolver_function)
+ key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml)
def get_blob_client(
self, container, # type: Union[ContainerProperties, str]
@@ -729,4 +732,4 @@ def get_blob_client(
credential=self.credential, api_version=self.api_version, _configuration=self._config,
_pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts,
require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key,
- key_resolver_function=self.key_resolver_function)
+ key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml)
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py
index b63556ba61bc..d8ebad17a640 100644
--- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py
+++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py
@@ -34,6 +34,7 @@
return_response_headers,
return_headers_and_deserialized)
from ._generated import AzureBlobStorage
+from ._generated import models as generated_models
from ._generated.models import SignedIdentifier
from ._deserialize import deserialize_container_properties
from ._serialize import get_modify_conditions, get_container_cpk_scope_info, get_api_version, get_access_conditions
@@ -156,6 +157,8 @@ def __init__(
self._query_str, credential = self._format_query_string(sas_token, credential)
super(ContainerClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs)
self._client = AzureBlobStorage(self.url, pipeline=self._pipeline)
+ if not self._msrest_xml:
+ self._custom_xml_deserializer(generated_models)
default_api_version = self._client._config.version # pylint: disable=protected-access
self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access
@@ -769,7 +772,41 @@ def list_blobs(self, name_starts_with=None, include=None, **kwargs):
timeout=timeout,
**kwargs)
return ItemPaged(
- command, prefix=name_starts_with, results_per_page=results_per_page,
+ command,
+ prefix=name_starts_with,
+ results_per_page=results_per_page,
+ select=None,
+ deserializer=self._client._deserialize, # pylint: disable=protected-access
+ page_iterator_class=BlobPropertiesPaged)
+
+ @distributed_trace
+ def list_blob_names(self, **kwargs):
+ # type: (**Any) -> ItemPaged[str]
+ """Returns a generator to list the names of blobs under the specified container.
+ The generator will lazily follow the continuation tokens returned by
+ the service.
+
+ :keyword str name_starts_with:
+ Filters the results to return only blobs whose names
+ begin with the specified prefix.
+ :keyword int timeout:
+ The timeout parameter is expressed in seconds.
+ :returns: An iterable (auto-paging) response of blob names as strings.
+ :rtype: ~azure.core.paging.ItemPaged[str]
+ """
+ name_starts_with = kwargs.pop('name_starts_with', None)
+ results_per_page = kwargs.pop('results_per_page', None)
+ timeout = kwargs.pop('timeout', None)
+ command = functools.partial(
+ self._client.container.list_blob_flat_segment,
+ timeout=timeout,
+ **kwargs)
+ return ItemPaged(
+ command,
+ prefix=name_starts_with,
+ results_per_page=results_per_page,
+ select=["name"],
+ deserializer=self._client._deserialize, # pylint: disable=protected-access
page_iterator_class=BlobPropertiesPaged)
@distributed_trace
@@ -816,6 +853,8 @@ def walk_blobs(
command,
prefix=name_starts_with,
results_per_page=results_per_page,
+ select=None,
+ deserializer=self._client._deserialize, # pylint: disable=protected-access
delimiter=delimiter)
@distributed_trace
@@ -1548,4 +1587,4 @@ def get_blob_client(
credential=self.credential, api_version=self.api_version, _configuration=self._config,
_pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts,
require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key,
- key_resolver_function=self.key_resolver_function)
+ key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml)
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py
index bec837429209..f4d4e68cdaa1 100644
--- a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py
+++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py
@@ -1452,12 +1452,12 @@ async def list_blob_flat_segment(
response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id'))
response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version'))
response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date'))
- deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response)
+ #deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response)
if cls:
- return cls(pipeline_response, deserialized, response_headers)
+ return cls(pipeline_response, None, response_headers)
- return deserialized
+ return None
list_blob_flat_segment.metadata = {'url': '/{containerName}'} # type: ignore
async def list_blob_hierarchy_segment(
@@ -1564,12 +1564,12 @@ async def list_blob_hierarchy_segment(
response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id'))
response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version'))
response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date'))
- deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response)
+ # deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response)
if cls:
- return cls(pipeline_response, deserialized, response_headers)
+ return cls(pipeline_response, None, response_headers)
- return deserialized
+ return None
list_blob_hierarchy_segment.metadata = {'url': '/{containerName}'} # type: ignore
async def get_account_info(
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py
index f01bbc4393fe..018c93984bf2 100644
--- a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py
+++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py
@@ -1471,12 +1471,12 @@ def list_blob_flat_segment(
response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id'))
response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version'))
response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date'))
- deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response)
+ #deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response)
if cls:
- return cls(pipeline_response, deserialized, response_headers)
+ return cls(pipeline_response, None, response_headers)
- return deserialized
+ return None # deserialized
list_blob_flat_segment.metadata = {'url': '/{containerName}'} # type: ignore
def list_blob_hierarchy_segment(
@@ -1584,12 +1584,12 @@ def list_blob_hierarchy_segment(
response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id'))
response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version'))
response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date'))
- deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response)
+ # deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response)
if cls:
- return cls(pipeline_response, deserialized, response_headers)
+ return cls(pipeline_response, None, response_headers)
- return deserialized
+ return None # deserialized
list_blob_hierarchy_segment.metadata = {'url': '/{containerName}'} # type: ignore
def get_account_info(
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py
index 309d37bd9583..6b4e1e6bd6b4 100644
--- a/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py
+++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py
@@ -7,14 +7,70 @@
from azure.core.paging import PageIterator, ItemPaged
from azure.core.exceptions import HttpResponseError
+
from ._deserialize import get_blob_properties_from_generated_code, parse_tags
-from ._generated.models import BlobItemInternal, BlobPrefix as GenBlobPrefix, FilterBlobItem
+from ._generated.models import FilterBlobItem
from ._models import BlobProperties, FilteredBlob
from ._shared.models import DictMixin
+from ._shared.xml_deserialization import unpack_xml_content
from ._shared.response_handlers import return_context_and_deserialized, process_storage_error
-class BlobPropertiesPaged(PageIterator):
+def deserialize_list_result(pipeline_response, *_):
+ payload = unpack_xml_content(pipeline_response.http_response)
+ location = pipeline_response.http_response.location_mode
+ return location, payload
+
+
+def load_xml_string(element, name):
+ node = element.find(name)
+ if node is None or not node.text:
+ return None
+ return node.text
+
+
+def load_xml_int(element, name):
+ node = element.find(name)
+ if node is None or not node.text:
+ return None
+ return int(node.text)
+
+
+def load_xml_bool(element, name):
+ node = load_xml_string(element, name)
+ if node and node.lower() == 'true':
+ return True
+ return False
+
+
+def load_single_node(element, name):
+ return element.find(name)
+
+
+def load_many_nodes(element, name, wrapper=None):
+ if wrapper:
+ element = load_single_node(element, wrapper)
+ return list(element.findall(name))
+
+
+def blob_properties_from_xml(element, select, deserializer):
+ if not select:
+ generated = deserializer.deserialize_data(element, 'BlobItemInternal')
+ return get_blob_properties_from_generated_code(generated)
+ blob = BlobProperties()
+ if 'name' in select:
+ blob.name = load_xml_string(element, 'Name')
+ if 'deleted' in select:
+ blob.deleted = load_xml_bool(element, 'Deleted')
+ if 'snapshot' in select:
+ blob.snapshot = load_xml_string(element, 'Snapshot')
+ if 'version' in select:
+ blob.version_id = load_xml_string(element, 'VersionId')
+ blob.is_current_version = load_xml_bool(element, 'IsCurrentVersion')
+ return blob
+
+
+class BlobPropertiesPaged(PageIterator): # pylint: disable=too-many-instance-attributes
"""An Iterable of Blob properties.
:ivar str service_endpoint: The service URL.
@@ -49,6 +105,8 @@ def __init__(
container=None,
prefix=None,
results_per_page=None,
+ select=None,
+ deserializer=None,
continuation_token=None,
delimiter=None,
location_mode=None):
@@ -58,10 +116,12 @@ def __init__(
continuation_token=continuation_token or ""
)
self._command = command
+ self._deserializer = deserializer
self.service_endpoint = None
self.prefix = prefix
self.marker = None
self.results_per_page = results_per_page
+ self.select = select
self.container = container
self.delimiter = delimiter
self.current_page = None
@@ -73,30 +133,29 @@ def _get_next_cb(self, continuation_token):
prefix=self.prefix,
marker=continuation_token or None,
maxresults=self.results_per_page,
- cls=return_context_and_deserialized,
+ cls=deserialize_list_result,
use_location=self.location_mode)
except HttpResponseError as error:
process_storage_error(error)
def _extract_data_cb(self, get_next_return):
self.location_mode, self._response = get_next_return
- self.service_endpoint = self._response.service_endpoint
- self.prefix = self._response.prefix
- self.marker = self._response.marker
- self.results_per_page = self._response.max_results
- self.container = self._response.container_name
- self.current_page = [self._build_item(item) for item in self._response.segment.blob_items]
+ self.service_endpoint = self._response.get('ServiceEndpoint')
+ self.prefix = load_xml_string(self._response, 'Prefix')
+ self.marker = load_xml_string(self._response, 'Marker')
+ self.results_per_page = load_xml_int(self._response, 'MaxResults')
+ self.container = self._response.get('ContainerName')
- return self._response.next_marker or None, self.current_page
+ blobs = load_many_nodes(self._response, 'Blob', wrapper='Blobs')
+ self.current_page = [self._build_item(blob) for blob in blobs]
+
+ next_marker = load_xml_string(self._response, 'NextMarker')
+ return next_marker or None, self.current_page
def _build_item(self, item):
- if isinstance(item, BlobProperties):
- return item
- if isinstance(item, BlobItemInternal):
- blob = get_blob_properties_from_generated_code(item) # pylint: disable=protected-access
- blob.container = self.container
- return blob
- return item
+ blob = blob_properties_from_xml(item, self.select, self._deserializer)
+ blob.container = self.container
+ return blob
class BlobPrefixPaged(BlobPropertiesPaged):
@@ -106,22 +165,26 @@ def __init__(self, *args, **kwargs):
def _extract_data_cb(self, get_next_return):
continuation_token, _ = super(BlobPrefixPaged, self)._extract_data_cb(get_next_return)
- self.current_page = self._response.segment.blob_prefixes + self._response.segment.blob_items
- self.current_page = [self._build_item(item) for item in self.current_page]
- self.delimiter = self._response.delimiter
+
+ blob_prefixes = load_many_nodes(self._response, 'BlobPrefix', wrapper='Blobs')
+ blob_prefixes = [self._build_prefix(blob) for blob in blob_prefixes]
+
+ self.current_page = blob_prefixes + self.current_page
+ self.delimiter = load_xml_string(self._response, 'Delimiter')
return continuation_token, self.current_page
- def _build_item(self, item):
- item = super(BlobPrefixPaged, self)._build_item(item)
- if isinstance(item, GenBlobPrefix):
- return BlobPrefix(
- self._command,
- container=self.container,
- prefix=item.name,
- results_per_page=self.results_per_page,
- location_mode=self.location_mode)
- return item
+ def _build_prefix(self, item):
+ return BlobPrefix(
+ self._command,
+ container=self.container,
+ prefix=load_xml_string(item, 'Name'),
+ results_per_page=self.results_per_page,
+ location_mode=self.location_mode,
+ select=self.select,
+ deserializer=self._deserializer,
+ delimiter=self.delimiter
+ )
class BlobPrefix(ItemPaged, DictMixin):
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py
index a2efa2170228..955a8073f5db 100644
--- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py
+++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py
@@ -35,6 +35,7 @@
AzureSasCredentialPolicy
)
+from .xml_deserialization import Deserializer
from .constants import STORAGE_OAUTH_SCOPE, SERVICE_HOST_BASE, CONNECTION_TIMEOUT, READ_TIMEOUT
from .models import LocationMode
from .authentication import SharedKeyCredentialPolicy
@@ -74,6 +75,7 @@ def __init__(
# type: (...) -> None
self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY)
self._hosts = kwargs.get("_hosts")
+ self._msrest_xml = kwargs.get('msrest_xml', False)
self.scheme = parsed_url.scheme
if service not in ["blob", "queue", "file-share", "dfs"]:
@@ -199,6 +201,20 @@ def api_version(self):
"""
return self._client._config.version # pylint: disable=protected-access
+ def _custom_xml_deserializer(self, generated_models):
+ """Reset the deserializer on the generated client to be Storage implementation"""
+ # pylint: disable=protected-access
+ client_models = {k: v for k, v in generated_models.__dict__.items() if isinstance(v, type)}
+ custom_deserialize = Deserializer(client_models)
+ self._client._deserialize = custom_deserialize
+ self._client.service._deserialize = custom_deserialize
+ self._client.container._deserialize = custom_deserialize
+ self._client.directory._deserialize = custom_deserialize
+ self._client.blob._deserialize = custom_deserialize
+ self._client.page_blob._deserialize = custom_deserialize
+ self._client.append_blob._deserialize = custom_deserialize
+ self._client.block_blob._deserialize = custom_deserialize
+
def _format_query_string(self, sas_token, credential, snapshot=None, share_snapshot=None):
query_str = "?"
if snapshot:
@@ -237,14 +253,13 @@ def _create_pipeline(self, credential, **kwargs):
config.transport = RequestsTransport(**kwargs)
policies = [
QueueMessagePolicy(),
+ config.headers_policy,
config.proxy_policy,
config.user_agent_policy,
StorageContentValidation(),
- ContentDecodePolicy(response_encoding="utf-8"),
RedirectPolicy(**kwargs),
StorageHosts(hosts=self._hosts, **kwargs),
config.retry_policy,
- config.headers_policy,
StorageRequestHook(**kwargs),
self._credential_policy,
config.logging_policy,
@@ -252,6 +267,8 @@ def _create_pipeline(self, credential, **kwargs):
DistributedTracingPolicy(**kwargs),
HttpLoggingPolicy(**kwargs)
]
+ if self._msrest_xml:
+ policies.insert(5, ContentDecodePolicy(response_encoding="utf-8"))
if kwargs.get("_additional_pipeline_policies"):
policies = policies + kwargs.get("_additional_pipeline_policies")
return config, Pipeline(config.transport, policies=policies)
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py
index 3e619c90fd71..5deca436299b 100644
--- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py
+++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py
@@ -15,12 +15,12 @@
from azure.core.async_paging import AsyncList
from azure.core.exceptions import HttpResponseError
from azure.core.pipeline.policies import (
- ContentDecodePolicy,
AsyncBearerTokenCredentialPolicy,
AsyncRedirectPolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
AzureSasCredentialPolicy,
+ ContentDecodePolicy
)
from azure.core.pipeline.transport import AsyncHttpTransport
@@ -97,7 +97,6 @@ def _create_pipeline(self, credential, **kwargs):
StorageContentValidation(),
StorageRequestHook(**kwargs),
self._credential_policy,
- ContentDecodePolicy(response_encoding="utf-8"),
AsyncRedirectPolicy(**kwargs),
StorageHosts(hosts=self._hosts, **kwargs), # type: ignore
config.retry_policy,
@@ -106,6 +105,8 @@ def _create_pipeline(self, credential, **kwargs):
DistributedTracingPolicy(**kwargs),
HttpLoggingPolicy(**kwargs),
]
+ if self._msrest_xml:
+ policies.insert(5, ContentDecodePolicy(response_encoding="utf-8"))
if kwargs.get("_additional_pipeline_policies"):
policies = policies + kwargs.get("_additional_pipeline_policies")
return config, AsyncPipeline(config.transport, policies=policies)
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py
new file mode 100644
index 000000000000..20f53f028354
--- /dev/null
+++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py
@@ -0,0 +1,587 @@
+# --------------------------------------------------------------------------
+#
+# Copyright (c) Microsoft Corporation. All rights reserved.
+#
+# The MIT License (MIT)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the ""Software""), to
+# deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+# sell copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+#
+# --------------------------------------------------------------------------
+
+from base64 import b64decode
+import datetime
+import decimal
+import email
+from enum import Enum
+import logging
+import re
+import xml.etree.ElementTree as ET
+
+import isodate
+from msrest.exceptions import DeserializationError, raise_with_traceback
+from msrest.serialization import (
+ TZ_UTC,
+ _FixedOffset
+)
+
+from azure.core.exceptions import DecodeError
+
+
+try:
+ basestring # pylint: disable=pointless-statement
+ unicode_str = unicode # type: ignore
+except NameError:
+ basestring = str # type: ignore
+ unicode_str = str # type: ignore
+
+_LOGGER = logging.getLogger(__name__)
+_valid_date = re.compile(
+ r'\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}'
+ r'\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?')
+
+try:
+ _long_type = long # type: ignore
+except NameError:
+ _long_type = int
+
+
+def unpack_xml_content(response_data, **kwargs):
+ """Extract the correct structure for deserialization.
+
+    Parse the HTTP response body as an XML element tree.
+
+    The response body is read via ``response_data.body()`` and parsed with
+    ``xml.etree.ElementTree``; the root element is returned.
+
+    If the body is not well-formed XML, the underlying ``ET.ParseError`` is
+    re-raised as a :class:`~azure.core.exceptions.DecodeError`.
+
+    :param response_data: The HTTP response whose body holds the XML payload.
+    :param kwargs: Extra arguments forwarded to the raised ``DecodeError``.
+    :raises ~azure.core.exceptions.DecodeError: If the body is not valid XML.
+ """
+ try:
+ return ET.fromstring(response_data.body()) # nosec
+ except ET.ParseError:
+ _LOGGER.critical("Response body invalid XML")
+ raise_with_traceback(DecodeError, message="XML is invalid", response=response_data, **kwargs)
+
+
+def deserialize_bytearray(attr, *_):
+ """Deserialize string into bytearray.
+
+ :param str attr: response string to be deserialized.
+ :rtype: bytearray
+ :raises: TypeError if string format invalid.
+ """
+ return bytearray(b64decode(attr))
+
+
+def deserialize_base64(attr, *_):
+ """Deserialize base64 encoded string into string.
+
+ :param str attr: response string to be deserialized.
+    :rtype: bytes
+ :raises: TypeError if string format invalid.
+ """
+ padding = '=' * (3 - (len(attr) + 3) % 4)
+ attr = attr + padding
+ encoded = attr.replace('-', '+').replace('_', '/')
+ return b64decode(encoded)
+
+
+def deserialize_decimal(attr, *_):
+ """Deserialize string into Decimal object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: Decimal
+ :raises: DeserializationError if string format invalid.
+ """
+ try:
+ return decimal.Decimal(attr)
+ except decimal.DecimalException as err:
+ msg = "Invalid decimal {}".format(attr)
+ raise_with_traceback(DeserializationError, msg, err)
+
+
+def deserialize_bool(attr, *_):
+ """Deserialize string into bool.
+
+ :param str attr: response string to be deserialized.
+ :rtype: bool
+ :raises: TypeError if string format is not valid.
+ """
+ if attr in [True, False, 1, 0]:
+ return bool(attr)
+ if isinstance(attr, basestring):
+ if attr.lower() in ['true', '1']:
+ return True
+ if attr.lower() in ['false', '0']:
+ return False
+ raise TypeError("Invalid boolean value: {}".format(attr))
+
+
+def deserialize_int(attr, *_):
+ """Deserialize string into int.
+
+ :param str attr: response string to be deserialized.
+ :rtype: int
+ :raises: ValueError or TypeError if string format invalid.
+ """
+ return int(attr)
+
+
+def deserialize_float(attr, *_):
+ """Deserialize string into float.
+
+ :param str attr: response string to be deserialized.
+ :rtype: float
+ :raises: ValueError if string format invalid.
+ """
+ return float(attr)
+
+
+def deserialize_long(attr, *_):
+ """Deserialize string into long (Py2) or int (Py3).
+
+ :param str attr: response string to be deserialized.
+ :rtype: long or int
+ :raises: ValueError if string format invalid.
+ """
+ return _long_type(attr)
+
+
+def deserialize_duration(attr, *_):
+ """Deserialize ISO-8601 formatted string into TimeDelta object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: TimeDelta
+ :raises: DeserializationError if string format invalid.
+ """
+ try:
+ duration = isodate.parse_duration(attr)
+ except(ValueError, OverflowError, AttributeError) as err:
+ msg = "Cannot deserialize duration object."
+ raise_with_traceback(DeserializationError, msg, err)
+ else:
+ return duration
+
+
+def deserialize_date(attr, *_):
+ """Deserialize ISO-8601 formatted string into Date object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: Date
+ :raises: DeserializationError if string format invalid.
+ """
+ if re.search(r"[^\W\d_]", attr, re.I + re.U):
+ raise DeserializationError("Date must have only digits and -. Received: %s" % attr)
+ # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception.
+ return isodate.parse_date(attr, defaultmonth=None, defaultday=None)
+
+
+def deserialize_time(attr, *_):
+ """Deserialize ISO-8601 formatted string into time object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: datetime.time
+ :raises: DeserializationError if string format invalid.
+ """
+ if re.search(r"[^\W\d_]", attr, re.I + re.U):
+ raise DeserializationError("Date must have only digits and -. Received: %s" % attr)
+ return isodate.parse_time(attr)
+
+
+def deserialize_rfc(attr, *_):
+ """Deserialize RFC-1123 formatted string into Datetime object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: Datetime
+ :raises: DeserializationError if string format invalid.
+ """
+ try:
+ parsed_date = email.utils.parsedate_tz(attr)
+ date_obj = datetime.datetime(
+ *parsed_date[:6],
+ tzinfo=_FixedOffset(datetime.timedelta(minutes=(parsed_date[9] or 0)/60))
+ )
+ if not date_obj.tzinfo:
+ date_obj = date_obj.astimezone(tz=TZ_UTC)
+ except ValueError as err:
+ msg = "Cannot deserialize to rfc datetime object."
+ raise_with_traceback(DeserializationError, msg, err)
+ else:
+ return date_obj
+
+
+def deserialize_iso(attr, *_):
+ """Deserialize ISO-8601 formatted string into Datetime object.
+
+ :param str attr: response string to be deserialized.
+ :rtype: Datetime
+ :raises: DeserializationError if string format invalid.
+ """
+ try:
+ attr = attr.upper()
+ match = _valid_date.match(attr)
+ if not match:
+ raise ValueError("Invalid datetime string: " + attr)
+
+ check_decimal = attr.split('.')
+ if len(check_decimal) > 1:
+ decimal_str = ""
+ for digit in check_decimal[1]:
+ if digit.isdigit():
+ decimal_str += digit
+ else:
+ break
+ if len(decimal_str) > 6:
+ attr = attr.replace(decimal_str, decimal_str[0:6])
+
+ date_obj = isodate.parse_datetime(attr)
+ test_utc = date_obj.utctimetuple()
+ if test_utc.tm_year > 9999 or test_utc.tm_year < 1:
+ raise OverflowError("Hit max or min date")
+ except(ValueError, OverflowError, AttributeError) as err:
+ msg = "Cannot deserialize datetime object."
+ raise_with_traceback(DeserializationError, msg, err)
+ else:
+ return date_obj
+
+
+def deserialize_object(attr, *_):
+ """Deserialize a generic object.
+ This will be handled as a dictionary.
+
+ :param dict attr: Dictionary to be deserialized.
+ :rtype: dict
+ :raises: TypeError if non-builtin datatype encountered.
+ """
+ # Do no recurse on XML, just return the tree as-is
+ # TODO: This probably needs work
+ return attr
+
+
+def deserialize_unix(attr, *_):
+    """Deserialize a Unix-epoch timestamp into a Datetime object.
+    The input is represented as seconds since the epoch.
+
+    :param int attr: Unix timestamp (in seconds) to be deserialized.
+    :rtype: Datetime
+    :raises: DeserializationError if format invalid
+ """
+ try:
+ date_obj = datetime.datetime.fromtimestamp(int(attr), TZ_UTC)
+ except ValueError as err:
+ msg = "Cannot deserialize to unix datetime object."
+ raise_with_traceback(DeserializationError, msg, err)
+ else:
+ return date_obj
+
+
+def deserialize_unicode(data, *_):
+ """Preserve unicode objects in Python 2, otherwise return data
+ as a string.
+
+ :param str data: response string to be deserialized.
+ :rtype: str or unicode
+ """
+ if data is None:
+ return ""
+ # We might be here because we have an enum modeled as string,
+ # and we try to deserialize a partial dict with enum inside
+ if isinstance(data, Enum):
+ return data
+
+ # Consider this is real string
+ try:
+ if isinstance(data, unicode):
+ return data
+ except NameError:
+ return str(data)
+ else:
+ return str(data)
+
+
+def deserialize_enum(data, enum_obj):
+ """Deserialize string into enum object.
+
+ If the string is not a valid enum value it will be returned as-is
+ and a warning will be logged.
+
+ :param str data: Response string to be deserialized. If this value is
+ None or invalid it will be returned as-is.
+ :param Enum enum_obj: Enum object to deserialize to.
+ :rtype: Enum
+ """
+ if isinstance(data, enum_obj) or data is None:
+ return data
+ if isinstance(data, Enum):
+ data = data.value
+ if isinstance(data, int):
+ # Workaround. We might consider remove it in the future.
+ # https://github.com/Azure/azure-rest-api-specs/issues/141
+ try:
+ return list(enum_obj.__members__.values())[data]
+ except IndexError:
+ error = "{!r} is not a valid index for enum {!r}"
+ raise DeserializationError(error.format(data, enum_obj))
+ try:
+ return enum_obj(str(data))
+ except ValueError:
+ for enum_value in enum_obj:
+ if enum_value.value.lower() == str(data).lower():
+ return enum_value
+ # We don't fail anymore for unknown value, we deserialize as a string
+ _LOGGER.warning("Deserializer is not able to find %s as valid enum in %s", data, enum_obj)
+ return deserialize_unicode(data)
+
+
+def instantiate_model(response, attrs, additional_properties=None):
+ """Instantiate a response model passing in deserialized args.
+
+ :param response: The response model class.
+    :param attrs: The deserialized response attributes.
+ """
+ try:
+ readonly = [k for k, v in response._validation.items() if v.get('readonly')] # pylint:disable=protected-access
+ const = [k for k, v in response._validation.items() if v.get('constant')] # pylint:disable=protected-access
+ kwargs = {k: v for k, v in attrs.items() if k not in readonly + const}
+ response_obj = response(**kwargs)
+ for attr in readonly:
+ setattr(response_obj, attr, attrs.get(attr))
+ if additional_properties:
+ response_obj.additional_properties = additional_properties
+ return response_obj
+ except Exception as err:
+ msg = "Unable to deserialize {} into model {}. ".format(
+ kwargs, response)
+ raise DeserializationError(msg + str(err))
+
+
+def multi_xml_key_extractor(attr_desc, data, subtype):
+ xml_desc = attr_desc.get('xml', {})
+ xml_name = xml_desc.get('name', attr_desc['key'])
+ is_wrapped = xml_desc.get("wrapped", False)
+ subtype_xml_map = getattr(subtype, "_xml_map", {})
+ if is_wrapped:
+ items_name = xml_name
+ elif subtype:
+ items_name = subtype_xml_map.get('name', xml_name)
+ else:
+ items_name = xml_desc.get("itemsName", xml_name)
+ children = data.findall(items_name)
+ if is_wrapped:
+ if len(children) == 0:
+ return None
+ return list(children[0])
+ return children
+
+def xml_key_extractor(attr_desc, data, subtype):
+ xml_desc = attr_desc.get('xml', {})
+ xml_name = xml_desc.get('name', attr_desc['key'])
+
+ # If it's an attribute, that's simple
+ if xml_desc.get("attr", False):
+ return data.get(xml_name)
+
+ # If it's x-ms-text, that's simple too
+ if xml_desc.get("text", False):
+ return data.text
+
+ subtype_xml_map = getattr(subtype, "_xml_map", {})
+ xml_name = subtype_xml_map.get('name', xml_name)
+ return data.find(xml_name)
+
+
+class Deserializer(object):
+ """Response object model deserializer.
+
+ :param dict classes: Class type dictionary for deserializing complex types.
+ """
+ def __init__(self, classes=None):
+ self.deserialize_type = {
+ 'str': deserialize_unicode,
+ 'int': deserialize_int,
+ 'bool': deserialize_bool,
+ 'float': deserialize_float,
+ 'iso-8601': deserialize_iso,
+ 'rfc-1123': deserialize_rfc,
+ 'unix-time': deserialize_unix,
+ 'duration': deserialize_duration,
+ 'date': deserialize_date,
+ 'time': deserialize_time,
+ 'decimal': deserialize_decimal,
+ 'long': deserialize_long,
+ 'bytearray': deserialize_bytearray,
+ 'base64': deserialize_base64,
+ 'object': deserialize_object,
+ '[]': self.deserialize_iter,
+ '{}': self.deserialize_dict
+ }
+
+ self.dependencies = dict(classes) if classes else {}
+
+ def __call__(self, target_obj, response_data, **kwargs):
+ """Call the deserializer to process a REST response.
+
+ :param str target_obj: Target data type to deserialize to.
+ :param requests.Response response_data: REST response object.
+ :param str content_type: Swagger "produces" if available.
+ :raises: DeserializationError if deserialization fails.
+ :return: Deserialized object.
+ """
+ try:
+ # First, unpack the response if we have one.
+ response_data = unpack_xml_content(response_data.http_response, **kwargs)
+ except AttributeError:
+ pass
+ if response_data is None:
+ # No data. Moving on.
+ return None
+ #return self._deserialize(target_obj, response_data)
+ return self.deserialize_data(response_data, target_obj)
+
+ def failsafe_deserialize(self, target_obj, data, content_type=None):
+ """Ignores any errors encountered in deserialization,
+ and falls back to not deserializing the object. Recommended
+ for use in error deserialization, as we want to return the
+ HttpResponseError to users, and not have them deal with
+ a deserialization error.
+
+ :param str target_obj: The target object type to deserialize to.
+        :param str/dict data: The response data to deserialize.
+ :param str content_type: Swagger "produces" if available.
+ """
+ try:
+ return self(target_obj, data, content_type=content_type)
+ except: # pylint: disable=bare-except
+ _LOGGER.warning(
+ "Ran into a deserialization error. Ignoring since this is failsafe deserialization",
+ exc_info=True
+ )
+ return None
+
+ def _deserialize(self, target_obj, data):
+ """Call the deserializer on a model.
+
+ Data needs to be already deserialized as JSON or XML ElementTree
+
+ :param str target_obj: Target data type to deserialize to.
+ :param object data: Object to deserialize.
+ :raises: DeserializationError if deserialization fails.
+ :return: Deserialized object.
+ """
+ try:
+ model_type = self.dependencies[target_obj]
+ if issubclass(model_type, Enum):
+ return deserialize_enum(data.text, model_type)
+ except KeyError:
+ return self.deserialize_data(data, target_obj)
+
+ if data is None:
+ return data
+ try:
+ attributes = model_type._attribute_map # pylint:disable=protected-access
+ d_attrs = {}
+ include_extra_props = False
+ for attr, attr_desc in attributes.items():
+ # Check empty string. If it's not empty, someone has a real "additionalProperties"...
+ if attr == "additional_properties" and attr_desc["key"] == '':
+ include_extra_props = True
+ continue
+ attr_type = attr_desc["type"]
+ try:
+ # TODO: Validate this subtype logic
+ subtype = self.dependencies[attr_type.strip('[]{}')]
+ except KeyError:
+ subtype = None
+ if attr_type[0] == '[':
+ raw_value = multi_xml_key_extractor(attr_desc, data, subtype)
+ else:
+ raw_value = xml_key_extractor(attr_desc, data, subtype)
+ value = self.deserialize_data(raw_value, attr_type)
+ d_attrs[attr] = value
+ except (AttributeError, TypeError, KeyError) as err:
+ msg = "Unable to deserialize to object: " + str(target_obj)
+ raise_with_traceback(DeserializationError, msg, err)
+ else:
+ if include_extra_props:
+ extra = {el.tag: el.text for el in data if el.tag not in d_attrs}
+ return instantiate_model(model_type, d_attrs, extra)
+ return instantiate_model(model_type, d_attrs)
+
+ def deserialize_data(self, data, data_type):
+ """Process data for deserialization according to data type.
+
+ :param str data: The response string to be deserialized.
+ :param str data_type: The type to deserialize to.
+ :raises: DeserializationError if deserialization fails.
+ :return: Deserialized object.
+ """
+ if not data_type or data is None:
+ return None
+ try:
+ xml_data = data.text
+ except AttributeError:
+ xml_data = data
+
+ try:
+ basic_deserialize = self.deserialize_type[data_type]
+ if not xml_data and data_type != 'str':
+ return None
+ return basic_deserialize(xml_data, data_type)
+ except KeyError:
+ pass
+ except (ValueError, TypeError, AttributeError) as err:
+ msg = "Unable to deserialize response data."
+ msg += " Data: {}, {}".format(data, data_type)
+ raise_with_traceback(DeserializationError, msg, err)
+
+ try:
+ iter_type = data_type[0] + data_type[-1]
+ if iter_type in self.deserialize_type:
+ return self.deserialize_type[iter_type](data, data_type[1:-1])
+ except (ValueError, TypeError, AttributeError) as err:
+ msg = "Unable to deserialize response data."
+ msg += " Data: {}, {}".format(data, data_type)
+ raise_with_traceback(DeserializationError, msg, err)
+ else:
+ return self._deserialize(data_type, data)
+
+ def deserialize_iter(self, attr, iter_type):
+ """Deserialize an iterable.
+
+ :param list attr: Iterable to be deserialized.
+ :param str iter_type: The type of object in the iterable.
+ :rtype: list
+ """
+ return [self.deserialize_data(a, iter_type) for a in list(attr)]
+
+ def deserialize_dict(self, attr, dict_type):
+ """Deserialize a dictionary.
+
+ :param dict/list attr: Dictionary to be deserialized. Also accepts
+ a list of key, value pairs.
+ :param str dict_type: The object type of the items in the dictionary.
+ :rtype: dict
+ """
+ # Transform value into {"Key": "value"}
+ attr = {el.tag: el.text for el in attr}
+ return {k: self.deserialize_data(v, dict_type) for k, v in attr.items()}
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py
index d13de28f8711..ec5266183770 100644
--- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py
+++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py
@@ -21,6 +21,7 @@
from .._deserialize import get_page_ranges_result, parse_tags, deserialize_pipeline_response_into_cls
from .._serialize import get_modify_conditions, get_api_version, get_access_conditions
from .._generated.aio import AzureBlobStorage
+from .._generated import models as generated_models
from .._generated.models import CpkInfo
from .._deserialize import deserialize_blob_properties
from .._blob_client import BlobClient as BlobClientBase
@@ -120,6 +121,8 @@ def __init__(
credential=credential,
**kwargs)
self._client = AzureBlobStorage(url=self.url, pipeline=self._pipeline)
+ if not self._msrest_xml:
+ self._custom_xml_deserializer(generated_models)
default_api_version = self._client._config.version # pylint: disable=protected-access
self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py
index d50661d8e2d7..3a2fd1eb4b14 100644
--- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py
+++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py
@@ -24,6 +24,7 @@
from .._shared.parser import _to_utc_datetime
from .._shared.response_handlers import parse_to_internal_user_delegation_key
from .._generated.aio import AzureBlobStorage
+from .._generated import models as generated_models
from .._generated.models import StorageServiceProperties, KeyInfo
from .._blob_service_client import BlobServiceClient as BlobServiceClientBase
from ._container_client_async import ContainerClient
@@ -118,6 +119,8 @@ def __init__(
credential=credential,
**kwargs)
self._client = AzureBlobStorage(url=self.url, pipeline=self._pipeline)
+ if not self._msrest_xml:
+ self._custom_xml_deserializer(generated_models)
default_api_version = self._client._config.version # pylint: disable=protected-access
self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access
@@ -619,7 +622,7 @@ def get_container_client(self, container):
credential=self.credential, api_version=self.api_version, _configuration=self._config,
_pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts,
require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key,
- key_resolver_function=self.key_resolver_function)
+ key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml)
def get_blob_client(
self, container, # type: Union[ContainerProperties, str]
@@ -674,4 +677,4 @@ def get_blob_client(
credential=self.credential, api_version=self.api_version, _configuration=self._config,
_pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts,
require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key,
- key_resolver_function=self.key_resolver_function)
+ key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml)
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py
index cd0164392ab6..9a919ec6c36b 100644
--- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py
+++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py
@@ -26,6 +26,7 @@
return_response_headers,
return_headers_and_deserialized)
from .._generated.aio import AzureBlobStorage
+from .._generated import models as generated_models
from .._generated.models import SignedIdentifier
from .._deserialize import deserialize_container_properties
from .._serialize import get_modify_conditions, get_container_cpk_scope_info, get_api_version, get_access_conditions
@@ -117,6 +118,8 @@ def __init__(
credential=credential,
**kwargs)
self._client = AzureBlobStorage(url=self.url, pipeline=self._pipeline)
+ if not self._msrest_xml:
+ self._custom_xml_deserializer(generated_models)
default_api_version = self._client._config.version # pylint: disable=protected-access
self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access
@@ -633,6 +636,39 @@ def list_blobs(self, name_starts_with=None, include=None, **kwargs):
command,
prefix=name_starts_with,
results_per_page=results_per_page,
+ select=None,
+ deserializer=self._client._deserialize, # pylint: disable=protected-access
+ page_iterator_class=BlobPropertiesPaged
+ )
+
+ @distributed_trace
+ def list_blob_names(self, **kwargs):
+ # type: (**Any) -> AsyncItemPaged[str]
+ """Returns a generator to list the names of blobs under the specified container.
+ The generator will lazily follow the continuation tokens returned by
+ the service.
+
+ :keyword str name_starts_with:
+ Filters the results to return only blobs whose names
+ begin with the specified prefix.
+ :keyword int timeout:
+ The timeout parameter is expressed in seconds.
+ :returns: An iterable (auto-paging) response of blob names as strings.
+ :rtype: ~azure.core.async_paging.AsyncItemPaged[str]
+ """
+ name_starts_with = kwargs.pop('name_starts_with', None)
+ results_per_page = kwargs.pop('results_per_page', None)
+ timeout = kwargs.pop('timeout', None)
+ command = functools.partial(
+ self._client.container.list_blob_flat_segment,
+ timeout=timeout,
+ **kwargs)
+ return AsyncItemPaged(
+ command,
+ prefix=name_starts_with,
+ results_per_page=results_per_page,
+ select=["name"],
+ deserializer=self._client._deserialize, # pylint: disable=protected-access
page_iterator_class=BlobPropertiesPaged
)
@@ -680,6 +716,8 @@ def walk_blobs(
command,
prefix=name_starts_with,
results_per_page=results_per_page,
+ select=None,
+ deserializer=self._client._deserialize, # pylint: disable=protected-access
delimiter=delimiter)
@distributed_trace_async
@@ -1206,4 +1244,4 @@ def get_blob_client(
credential=self.credential, api_version=self.api_version, _configuration=self._config,
_pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts,
require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key,
- key_resolver_function=self.key_resolver_function)
+ key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml)
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py
index 058572fd270d..9a11087d7020 100644
--- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py
+++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py
@@ -7,14 +7,18 @@
from azure.core.async_paging import AsyncPageIterator, AsyncItemPaged
from azure.core.exceptions import HttpResponseError
-from .._deserialize import get_blob_properties_from_generated_code
-from .._models import BlobProperties
-from .._generated.models import BlobItemInternal, BlobPrefix as GenBlobPrefix
from .._shared.models import DictMixin
-from .._shared.response_handlers import return_context_and_deserialized, process_storage_error
+from .._shared.response_handlers import process_storage_error
+from .._list_blobs_helper import (
+ deserialize_list_result,
+ load_many_nodes,
+ load_xml_string,
+ load_xml_int,
+ blob_properties_from_xml
+)
-class BlobPropertiesPaged(AsyncPageIterator):
+class BlobPropertiesPaged(AsyncPageIterator): # pylint: disable=too-many-instance-attributes
"""An Iterable of Blob properties.
:ivar str service_endpoint: The service URL.
@@ -48,6 +52,8 @@ def __init__(
container=None,
prefix=None,
results_per_page=None,
+ select=None,
+ deserializer=None,
continuation_token=None,
delimiter=None,
location_mode=None):
@@ -57,10 +63,12 @@ def __init__(
continuation_token=continuation_token or ""
)
self._command = command
+ self._deserializer = deserializer
self.service_endpoint = None
self.prefix = prefix
self.marker = None
self.results_per_page = results_per_page
+ self.select = select
self.container = container
self.delimiter = delimiter
self.current_page = None
@@ -72,30 +80,29 @@ async def _get_next_cb(self, continuation_token):
prefix=self.prefix,
marker=continuation_token or None,
maxresults=self.results_per_page,
- cls=return_context_and_deserialized,
+ cls=deserialize_list_result,
use_location=self.location_mode)
except HttpResponseError as error:
process_storage_error(error)
async def _extract_data_cb(self, get_next_return):
self.location_mode, self._response = get_next_return
- self.service_endpoint = self._response.service_endpoint
- self.prefix = self._response.prefix
- self.marker = self._response.marker
- self.results_per_page = self._response.max_results
- self.container = self._response.container_name
- self.current_page = [self._build_item(item) for item in self._response.segment.blob_items]
+ self.service_endpoint = self._response.get('ServiceEndpoint')
+ self.prefix = load_xml_string(self._response, 'Prefix')
+ self.marker = load_xml_string(self._response, 'Marker')
+ self.results_per_page = load_xml_int(self._response, 'MaxResults')
+ self.container = self._response.get('ContainerName')
- return self._response.next_marker or None, self.current_page
+ blobs = load_many_nodes(self._response, 'Blob', wrapper='Blobs')
+ self.current_page = [self._build_item(blob) for blob in blobs]
+
+ next_marker = load_xml_string(self._response, 'NextMarker')
+ return next_marker or None, self.current_page
def _build_item(self, item):
- if isinstance(item, BlobProperties):
- return item
- if isinstance(item, BlobItemInternal):
- blob = get_blob_properties_from_generated_code(item) # pylint: disable=protected-access
- blob.container = self.container
- return blob
- return item
+ blob = blob_properties_from_xml(item, self.select, self._deserializer)
+ blob.container = self.container
+ return blob
class BlobPrefix(AsyncItemPaged, DictMixin):
@@ -144,20 +151,21 @@ def __init__(self, *args, **kwargs):
self.name = self.prefix
async def _extract_data_cb(self, get_next_return):
- continuation_token, _ = await super(BlobPrefixPaged, self)._extract_data_cb(get_next_return)
- self.current_page = self._response.segment.blob_prefixes + self._response.segment.blob_items
- self.current_page = [self._build_item(item) for item in self.current_page]
- self.delimiter = self._response.delimiter
-
+ continuation_token, current_page = await super(BlobPrefixPaged, self)._extract_data_cb(get_next_return)
+ blob_prefixes = load_many_nodes(self._response, 'BlobPrefix', wrapper='Blobs')
+ blob_prefixes = [self._build_prefix(blob) for blob in blob_prefixes]
+ self.current_page = blob_prefixes + current_page
+ self.delimiter = load_xml_string(self._response, 'Delimiter')
return continuation_token, self.current_page
- def _build_item(self, item):
- item = super(BlobPrefixPaged, self)._build_item(item)
- if isinstance(item, GenBlobPrefix):
- return BlobPrefix(
- self._command,
- container=self.container,
- prefix=item.name,
- results_per_page=self.results_per_page,
- location_mode=self.location_mode)
- return item
+ def _build_prefix(self, item):
+ return BlobPrefix(
+ self._command,
+ container=self.container,
+ prefix=load_xml_string(item, 'Name'),
+ results_per_page=self.results_per_page,
+ location_mode=self.location_mode,
+ select=self.select,
+ deserializer=self._deserializer,
+ delimiter=self.delimiter
+ )
diff --git a/sdk/storage/azure-storage-blob/tests/perfstress_tests/T1_legacy_tests/list_blobs.py b/sdk/storage/azure-storage-blob/tests/perfstress_tests/T1_legacy_tests/list_blobs.py
index b3a55bcf23b9..aedc564e0109 100644
--- a/sdk/storage/azure-storage-blob/tests/perfstress_tests/T1_legacy_tests/list_blobs.py
+++ b/sdk/storage/azure-storage-blob/tests/perfstress_tests/T1_legacy_tests/list_blobs.py
@@ -17,8 +17,12 @@ async def global_setup(self):
blob=b"")
def run_sync(self):
- for _ in self.service_client.list_blobs(container_name=self.container_name):
- pass
+ if self.args.name_only:
+ for _ in self.service_client.list_blob_names(container_name=self.container_name):
+ pass
+ else:
+ for _ in self.service_client.list_blobs(container_name=self.container_name):
+ pass
async def run_async(self):
raise NotImplementedError("Async not supported for legacy T1 tests.")
@@ -27,3 +31,4 @@ async def run_async(self):
def add_arguments(parser):
super(LegacyListBlobsTest, LegacyListBlobsTest).add_arguments(parser)
parser.add_argument('-c', '--count', nargs='?', type=int, help='Number of blobs to list. Defaults to 100', default=100)
+ parser.add_argument('--name-only', action='store_true', help='Return only blob name. Defaults to False', default=False)
diff --git a/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py b/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py
index ca46e67ffccb..8d9cdaf49829 100644
--- a/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py
+++ b/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py
@@ -25,6 +25,7 @@ def __init__(self, arguments):
self._client_kwargs['max_single_put_size'] = self.args.max_put_size
self._client_kwargs['max_block_size'] = self.args.max_block_size
self._client_kwargs['min_large_block_upload_threshold'] = self.args.buffer_threshold
+ self._client_kwargs['msrest_xml'] = not self.args.no_msrest
# self._client_kwargs['api_version'] = '2019-02-02' # Used only for comparison with T1 legacy tests
if not _ServiceTest.service_client or self.args.no_client_share:
@@ -46,6 +47,7 @@ def add_arguments(parser):
parser.add_argument('--max-concurrency', nargs='?', type=int, help='Maximum number of concurrent threads used for data transfer. Defaults to 1', default=1)
parser.add_argument('-s', '--size', nargs='?', type=int, help='Size of data to transfer. Default is 10240.', default=10240)
parser.add_argument('--no-client-share', action='store_true', help='Create one ServiceClient per test instance. Default is to share a single ServiceClient.', default=False)
+    parser.add_argument('--no-msrest', action='store_true', help='Do not use the msrest XML deserialization pipeline. Defaults to False', default=False)
class _ContainerTest(_ServiceTest):
diff --git a/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py b/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py
index f5f35a86fff1..65894b044e3f 100644
--- a/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py
+++ b/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py
@@ -27,14 +27,23 @@ async def global_setup(self):
break
def run_sync(self):
- for _ in self.container_client.list_blobs():
- pass
+ if self.args.name_only:
+ for _ in self.container_client.list_blob_names():
+ pass
+ else:
+ for _ in self.container_client.list_blobs():
+ pass
async def run_async(self):
- async for _ in self.async_container_client.list_blobs():
- pass
+ if self.args.name_only:
+ async for _ in self.async_container_client.list_blob_names():
+ pass
+ else:
+ async for _ in self.async_container_client.list_blobs():
+ pass
@staticmethod
def add_arguments(parser):
super(ListBlobsTest, ListBlobsTest).add_arguments(parser)
parser.add_argument('-c', '--count', nargs='?', type=int, help='Number of blobs to list. Defaults to 100', default=100)
+ parser.add_argument('--name-only', action='store_true', help='Return only blob name. Defaults to False', default=False)
diff --git a/sdk/storage/azure-storage-blob/tests/test_blob_service_stats.py b/sdk/storage/azure-storage-blob/tests/test_blob_service_stats.py
index 1de16c8a6538..2a36c6b8f9c2 100644
--- a/sdk/storage/azure-storage-blob/tests/test_blob_service_stats.py
+++ b/sdk/storage/azure-storage-blob/tests/test_blob_service_stats.py
@@ -12,13 +12,13 @@
from _shared.testcase import GlobalStorageAccountPreparer, GlobalResourceGroupPreparer
-SERVICE_UNAVAILABLE_RESP_BODY = 'unavailable '
+SERVICE_UNAVAILABLE_RESP_BODY = b'unavailable '
-SERVICE_LIVE_RESP_BODY = 'liveWed, 19 Jan 2021 22:28:43 GMT '
+SERVICE_LIVE_RESP_BODY = b'liveWed, 19 Jan 2021 22:28:43 GMT '
# --Test Class -----------------------------------------------------------------
class ServiceStatsTest(StorageTestCase):
@@ -39,11 +39,11 @@ def _assert_stats_unavailable(self, stats):
@staticmethod
def override_response_body_with_live_status(response):
- response.http_response.text = lambda encoding=None: SERVICE_LIVE_RESP_BODY
+ response.http_response.body = lambda: SERVICE_LIVE_RESP_BODY
@staticmethod
def override_response_body_with_unavailable_status(response):
- response.http_response.text = lambda encoding=None: SERVICE_UNAVAILABLE_RESP_BODY
+ response.http_response.body = lambda: SERVICE_UNAVAILABLE_RESP_BODY
# --Test cases per service ---------------------------------------
@GlobalResourceGroupPreparer()
diff --git a/sdk/storage/azure-storage-blob/tests/test_blob_service_stats_async.py b/sdk/storage/azure-storage-blob/tests/test_blob_service_stats_async.py
index 380fa67b024d..4b545a2eafab 100644
--- a/sdk/storage/azure-storage-blob/tests/test_blob_service_stats_async.py
+++ b/sdk/storage/azure-storage-blob/tests/test_blob_service_stats_async.py
@@ -15,14 +15,14 @@
from devtools_testutils.storage.aio import AsyncStorageTestCase
-SERVICE_UNAVAILABLE_RESP_BODY = 'unavailable '
+SERVICE_UNAVAILABLE_RESP_BODY = b'unavailable '
-SERVICE_LIVE_RESP_BODY = 'liveWed, 19 Jan 2021 22:28:43 GMT '
+SERVICE_LIVE_RESP_BODY = b'liveWed, 19 Jan 2021 22:28:43 GMT '
class AiohttpTestTransport(AioHttpTransport):
@@ -55,11 +55,11 @@ def _assert_stats_unavailable(self, stats):
@staticmethod
def override_response_body_with_live_status(response):
- response.http_response.text = lambda encoding=None: SERVICE_LIVE_RESP_BODY
+ response.http_response.body = lambda: SERVICE_LIVE_RESP_BODY
@staticmethod
def override_response_body_with_unavailable_status(response):
- response.http_response.text = lambda encoding=None: SERVICE_UNAVAILABLE_RESP_BODY
+ response.http_response.body = lambda: SERVICE_UNAVAILABLE_RESP_BODY
# --Test cases per service ---------------------------------------
@GlobalResourceGroupPreparer()