Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POC] Blobs Partial list deserialization #19814

Closed
wants to merge 15 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
validate_and_format_range_headers)
from ._shared.response_handlers import return_response_headers, process_storage_error, return_headers_and_deserialized
from ._generated import AzureBlobStorage
from ._generated import models as generated_models
from ._generated.models import ( # pylint: disable=unused-import
DeleteSnapshotsOptionType,
BlobHTTPHeaders,
Expand Down Expand Up @@ -175,6 +176,8 @@ def __init__(
self._query_str, credential = self._format_query_string(sas_token, credential, snapshot=self.snapshot)
super(BlobClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs)
self._client = AzureBlobStorage(self.url, pipeline=self._pipeline)
if not self._msrest_xml:
self._custom_xml_deserializer(generated_models)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removable. The same goes for all the clients.

default_api_version = self._client._config.version # pylint: disable=protected-access
self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ._shared.response_handlers import return_response_headers, process_storage_error, \
parse_to_internal_user_delegation_key
from ._generated import AzureBlobStorage
from ._generated import models as generated_models
from ._generated.models import StorageServiceProperties, KeyInfo
from ._container_client import ContainerClient
from ._blob_client import BlobClient
Expand Down Expand Up @@ -134,6 +135,8 @@ def __init__(
self._query_str, credential = self._format_query_string(sas_token, credential)
super(BlobServiceClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs)
self._client = AzureBlobStorage(self.url, pipeline=self._pipeline)
if not self._msrest_xml:
self._custom_xml_deserializer(generated_models)
default_api_version = self._client._config.version # pylint: disable=protected-access
self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access

Expand Down Expand Up @@ -676,7 +679,7 @@ def get_container_client(self, container):
credential=self.credential, api_version=self.api_version, _configuration=self._config,
_pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts,
require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key,
key_resolver_function=self.key_resolver_function)
key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml)

def get_blob_client(
self, container, # type: Union[ContainerProperties, str]
Expand Down Expand Up @@ -729,4 +732,4 @@ def get_blob_client(
credential=self.credential, api_version=self.api_version, _configuration=self._config,
_pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts,
require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key,
key_resolver_function=self.key_resolver_function)
key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml)
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
return_response_headers,
return_headers_and_deserialized)
from ._generated import AzureBlobStorage
from ._generated import models as generated_models
from ._generated.models import SignedIdentifier
from ._deserialize import deserialize_container_properties
from ._serialize import get_modify_conditions, get_container_cpk_scope_info, get_api_version, get_access_conditions
Expand Down Expand Up @@ -156,6 +157,8 @@ def __init__(
self._query_str, credential = self._format_query_string(sas_token, credential)
super(ContainerClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs)
self._client = AzureBlobStorage(self.url, pipeline=self._pipeline)
if not self._msrest_xml:
self._custom_xml_deserializer(generated_models)
default_api_version = self._client._config.version # pylint: disable=protected-access
self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access

Expand Down Expand Up @@ -769,7 +772,41 @@ def list_blobs(self, name_starts_with=None, include=None, **kwargs):
timeout=timeout,
**kwargs)
return ItemPaged(
command, prefix=name_starts_with, results_per_page=results_per_page,
command,
prefix=name_starts_with,
results_per_page=results_per_page,
select=None,
deserializer=self._client._deserialize, # pylint: disable=protected-access
page_iterator_class=BlobPropertiesPaged)

@distributed_trace
def list_blob_names(self, **kwargs):
    # type: (**Any) -> ItemPaged[str]
    """List the blobs in this container, parsing only the name of each blob.

    The returned pager is lazy: it follows the continuation tokens returned
    by the service as it is iterated.

    :keyword str name_starts_with:
        Filters the results to return only blobs whose names
        begin with the specified prefix.
    :keyword int results_per_page:
        Forwarded to the pager as its page size; ``None`` when not supplied.
    :keyword int timeout:
        The timeout parameter is expressed in seconds.
    :returns: An iterable (auto-paging) response of blob names as strings.
    :rtype: ~azure.core.paging.ItemPaged[str]
    """
    # NOTE(review): pages are produced by BlobPropertiesPaged with
    # select=["name"]; confirm the yielded items really are plain ``str``
    # values as the signature advertises.
    timeout = kwargs.pop('timeout', None)
    page_size = kwargs.pop('results_per_page', None)
    name_prefix = kwargs.pop('name_starts_with', None)

    # Any keyword arguments left over are passed straight to the generated operation.
    list_segment = functools.partial(
        self._client.container.list_blob_flat_segment,
        timeout=timeout,
        **kwargs)

    paging_options = dict(
        prefix=name_prefix,
        results_per_page=page_size,
        select=["name"],
        deserializer=self._client._deserialize,  # pylint: disable=protected-access
        page_iterator_class=BlobPropertiesPaged)
    return ItemPaged(list_segment, **paging_options)

@distributed_trace
Expand Down Expand Up @@ -816,6 +853,8 @@ def walk_blobs(
command,
prefix=name_starts_with,
results_per_page=results_per_page,
select=None,
deserializer=self._client._deserialize, # pylint: disable=protected-access
delimiter=delimiter)

@distributed_trace
Expand Down Expand Up @@ -1548,4 +1587,4 @@ def get_blob_client(
credential=self.credential, api_version=self.api_version, _configuration=self._config,
_pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts,
require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key,
key_resolver_function=self.key_resolver_function)
key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml)
Original file line number Diff line number Diff line change
Expand Up @@ -1452,12 +1452,12 @@ async def list_blob_flat_segment(
response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id'))
response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version'))
response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date'))
deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response)
#deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is one of the bigger challenges to figure out. Currently the deserialization process is out of our hands. We would need to add directives to the autorest code gen for the list_blobs_flat_segment API to not deserialize the response payload. This should be possible by simply overwriting the output model to have no output. We then use the cls hook and do the deserialization ourselves.
In the case of the existing list_blobs API, this probably means manually using the existing msrest deserializer if we don't want to deal with the testing burden of validating the new deserializer for the old API.


if cls:
return cls(pipeline_response, deserialized, response_headers)
return cls(pipeline_response, None, response_headers)

return deserialized
return None
list_blob_flat_segment.metadata = {'url': '/{containerName}'} # type: ignore

async def list_blob_hierarchy_segment(
Expand Down Expand Up @@ -1564,12 +1564,12 @@ async def list_blob_hierarchy_segment(
response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id'))
response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version'))
response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date'))
deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response)
# deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response)

if cls:
return cls(pipeline_response, deserialized, response_headers)
return cls(pipeline_response, None, response_headers)

return deserialized
return None
list_blob_hierarchy_segment.metadata = {'url': '/{containerName}'} # type: ignore

async def get_account_info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1471,12 +1471,12 @@ def list_blob_flat_segment(
response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id'))
response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version'))
response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date'))
deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response)
#deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response)

if cls:
return cls(pipeline_response, deserialized, response_headers)
return cls(pipeline_response, None, response_headers)

return deserialized
return None # deserialized
list_blob_flat_segment.metadata = {'url': '/{containerName}'} # type: ignore

def list_blob_hierarchy_segment(
Expand Down Expand Up @@ -1584,12 +1584,12 @@ def list_blob_hierarchy_segment(
response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id'))
response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version'))
response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date'))
deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response)
# deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response)

if cls:
return cls(pipeline_response, deserialized, response_headers)
return cls(pipeline_response, None, response_headers)

return deserialized
return None # deserialized
list_blob_hierarchy_segment.metadata = {'url': '/{containerName}'} # type: ignore

def get_account_info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,70 @@

from azure.core.paging import PageIterator, ItemPaged
from azure.core.exceptions import HttpResponseError

from ._deserialize import get_blob_properties_from_generated_code, parse_tags
from ._generated.models import BlobItemInternal, BlobPrefix as GenBlobPrefix, FilterBlobItem
from ._generated.models import FilterBlobItem
from ._models import BlobProperties, FilteredBlob
from ._shared.models import DictMixin
from ._shared.xml_deserialization import unpack_xml_content
from ._shared.response_handlers import return_context_and_deserialized, process_storage_error


class BlobPropertiesPaged(PageIterator):
def deserialize_list_result(pipeline_response, *_):
payload = unpack_xml_content(pipeline_response.http_response)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this line here is replacing the ContentDecodePolicy that I removed from the pipeline. So this would already be unpacked if we put that policy back in.

location = pipeline_response.http_response.location_mode
return location, payload


def load_xml_string(element, name):
    """Return the text of the first child of ``element`` matching ``name``.

    :param element: XML element to search (anything exposing ``find``).
    :param str name: Tag name / path of the child to look up.
    :returns: The child's text, or ``None`` when the child is missing or
        its text is empty.
    """
    child = element.find(name)
    text = None if child is None else child.text
    # Collapse an empty string to None so callers only ever see real content.
    return text or None


def load_xml_int(element, name):
    """Return the integer value of the first child of ``element`` matching ``name``.

    :param element: XML element to search (anything exposing ``find``).
    :param str name: Tag name / path of the child to look up.
    :returns: ``int(text)`` of the child, or ``None`` when the child is
        missing or has no text.
    :raises ValueError: If the child's text is not a valid integer literal.
    """
    child = element.find(name)
    if child is not None and child.text:
        return int(child.text)
    return None


def load_xml_bool(element, name):
    """Return ``True`` only when the named child's text is ``true`` (any case).

    :param element: XML element to search.
    :param str name: Tag name / path of the child to look up.
    :returns: ``True`` for a case-insensitive ``'true'``; ``False`` for any
        other text and when the child is missing or empty.
    :rtype: bool
    """
    text = load_xml_string(element, name)
    return text is not None and text.lower() == 'true'


def load_single_node(element, name):
    """Return the first child of ``element`` matching ``name``.

    :param element: XML element to search (anything exposing ``find``).
    :param str name: Tag name / path of the child to look up.
    :returns: Whatever ``element.find(name)`` yields - the matching child,
        or ``None`` when there is no match.
    """
    node = element.find(name)
    return node


def load_many_nodes(element, name, wrapper=None):
    """Return every child matching ``name``, optionally inside a wrapper element.

    :param element: XML element to search (anything exposing ``find`` and
        ``findall``).
    :param str name: Tag name / path of the children to collect.
    :param str wrapper: Optional tag name of an intermediate element that
        contains the children; when given, the search starts from that element.
    :returns: A new list of matching elements. The list is empty when nothing
        matches, including when ``wrapper`` is given but not present.
    :rtype: list
    """
    if wrapper:
        element = element.find(wrapper)
        if element is None:
            # A payload without the wrapper element has no children to report;
            # treat it as an empty result instead of failing on ``None.findall``.
            return []
    return list(element.findall(name))


def blob_properties_from_xml(element, select, deserializer):
if not select:
generated = deserializer.deserialize_data(element, 'BlobItemInternal')
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is using the old msrest deserializer - so once we've altered the generated layer to not deserialize for us - keeping this should mean that the existing list_blobs doesn't change.

return get_blob_properties_from_generated_code(generated)
blob = BlobProperties()
if 'name' in select:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I implemented this select logic in case we wanted to return more from the payload than just the name. However, that seems unlikely - so we could refactor this out and simplify the logic a bit here.

blob.name = load_xml_string(element, 'Name')
if 'deleted' in select:
blob.deleted = load_xml_bool(element, 'Deleted')
if 'snapshot' in select:
blob.snapshot = load_xml_string(element, 'Snapshot')
if 'version' in select:
blob.version_id = load_xml_string(element, 'VersionId')
blob.is_current_version = load_xml_bool(element, 'IsCurrentVersion')
return blob


class BlobPropertiesPaged(PageIterator): # pylint: disable=too-many-instance-attributes
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like I updated the existing BlobPropertiesPaged - which means the perf of this model would be improved. However, if we wanted to leave the original list_blobs API completely untouched, we could revert the changes here and have the new list_blob_names API use its own custom Paged object.

"""An Iterable of Blob properties.

:ivar str service_endpoint: The service URL.
Expand Down Expand Up @@ -49,6 +105,8 @@ def __init__(
container=None,
prefix=None,
results_per_page=None,
select=None,
deserializer=None,
continuation_token=None,
delimiter=None,
location_mode=None):
Expand All @@ -58,10 +116,12 @@ def __init__(
continuation_token=continuation_token or ""
)
self._command = command
self._deserializer = deserializer
self.service_endpoint = None
self.prefix = prefix
self.marker = None
self.results_per_page = results_per_page
self.select = select
self.container = container
self.delimiter = delimiter
self.current_page = None
Expand All @@ -73,30 +133,29 @@ def _get_next_cb(self, continuation_token):
prefix=self.prefix,
marker=continuation_token or None,
maxresults=self.results_per_page,
cls=return_context_and_deserialized,
cls=deserialize_list_result,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the cls parameter I mentioned that we would use to hook into the deserialization.

use_location=self.location_mode)
except HttpResponseError as error:
process_storage_error(error)

def _extract_data_cb(self, get_next_return):
self.location_mode, self._response = get_next_return
self.service_endpoint = self._response.service_endpoint
self.prefix = self._response.prefix
self.marker = self._response.marker
self.results_per_page = self._response.max_results
self.container = self._response.container_name
self.current_page = [self._build_item(item) for item in self._response.segment.blob_items]
self.service_endpoint = self._response.get('ServiceEndpoint')
self.prefix = load_xml_string(self._response, 'Prefix')
self.marker = load_xml_string(self._response, 'Marker')
self.results_per_page = load_xml_int(self._response, 'MaxResults')
self.container = self._response.get('ContainerName')

return self._response.next_marker or None, self.current_page
blobs = load_many_nodes(self._response, 'Blob', wrapper='Blobs')
self.current_page = [self._build_item(blob) for blob in blobs]

next_marker = load_xml_string(self._response, 'NextMarker')
return next_marker or None, self.current_page

def _build_item(self, item):
if isinstance(item, BlobProperties):
return item
if isinstance(item, BlobItemInternal):
blob = get_blob_properties_from_generated_code(item) # pylint: disable=protected-access
blob.container = self.container
return blob
return item
blob = blob_properties_from_xml(item, self.select, self._deserializer)
blob.container = self.container
return blob


class BlobPrefixPaged(BlobPropertiesPaged):
Expand All @@ -106,22 +165,26 @@ def __init__(self, *args, **kwargs):

def _extract_data_cb(self, get_next_return):
continuation_token, _ = super(BlobPrefixPaged, self)._extract_data_cb(get_next_return)
self.current_page = self._response.segment.blob_prefixes + self._response.segment.blob_items
self.current_page = [self._build_item(item) for item in self.current_page]
self.delimiter = self._response.delimiter

blob_prefixes = load_many_nodes(self._response, 'BlobPrefix', wrapper='Blobs')
blob_prefixes = [self._build_prefix(blob) for blob in blob_prefixes]

self.current_page = blob_prefixes + self.current_page
self.delimiter = load_xml_string(self._response, 'Delimiter')

return continuation_token, self.current_page

def _build_item(self, item):
item = super(BlobPrefixPaged, self)._build_item(item)
if isinstance(item, GenBlobPrefix):
return BlobPrefix(
self._command,
container=self.container,
prefix=item.name,
results_per_page=self.results_per_page,
location_mode=self.location_mode)
return item
def _build_prefix(self, item):
    """Build a ``BlobPrefix`` for one ``BlobPrefix`` XML element.

    :param item: XML element whose ``Name`` child holds the prefix text.
    :returns: A ``BlobPrefix`` constructed with this iterator's command,
        container, page size, location mode, select list, deserializer and
        delimiter, using the element's name as its prefix.
    """
    prefix_name = load_xml_string(item, 'Name')
    # Everything except the prefix itself is copied from this page iterator.
    inherited = dict(
        container=self.container,
        prefix=prefix_name,
        results_per_page=self.results_per_page,
        location_mode=self.location_mode,
        select=self.select,
        deserializer=self._deserializer,
        delimiter=self.delimiter,
    )
    return BlobPrefix(self._command, **inherited)


class BlobPrefix(ItemPaged, DictMixin):
Expand Down
Loading