Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add read method to StorageStreamDownloader #24275

Merged
merged 16 commits into from
Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/storage/azure-storage-blob/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ This version and all future versions will require Python 3.7+. Python 3.6 is no

### Features Added
- Added support for `AzureNamedKeyCredential` as a valid `credential` type.
- Added standard `read` method to `StorageStreamDownloader`.
- Added support for async streams (classes with an async `read` method) to async `upload_blob`.

### Bugs Fixed
- Removed dead retry meachism from async `azure.storage.blob.aio.StorageStreamDownloader`.
Expand Down
110 changes: 98 additions & 12 deletions sdk/storage/azure-storage-blob/azure/storage/blob/_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
import warnings
from io import BytesIO
from typing import Generic, Iterator, TypeVar
from typing import Generic, Iterator, Optional, TypeVar

from azure.core.exceptions import DecodeError, HttpResponseError, IncompleteReadError
from azure.core.tracing.common import with_current_context
Expand Down Expand Up @@ -334,6 +334,7 @@ def __init__(
self._non_empty_ranges = None
self._response = None
self._encryption_data = None
self._offset = 0
jalauzon-msft marked this conversation as resolved.
Show resolved Hide resolved

# The cls is passed in via download_cls to avoid conflicting arg name with Generic.__new__
# but needs to be changed to cls in the request options.
Expand Down Expand Up @@ -504,6 +505,17 @@ def _initial_request(self):

return response

def _get_downloader_start_with_offset(self):
# Start where the initial request download ended
start = self._initial_range[1] + 1
# For encryption V2 only, adjust start to the end of the fetched data rather than download size
if self._encryption_options.get("key") is not None or self._encryption_options.get("resolver") is not None:
start = (self._start_range or 0) + len(self._current_content)

# Adjust the start based on any data read past the current content
start += (self._offset - len(self._current_content))
return start

def chunks(self):
# type: () -> Iterator[bytes]
"""Iterate over chunks in the download stream.
Expand Down Expand Up @@ -554,6 +566,73 @@ def chunks(self):
downloader=iter_downloader,
chunk_size=self._config.max_chunk_get_size)

def read(self, size: Optional[int] = -1) -> T:
"""
Read up to size bytes from the object and return them. If size
is specified as -1, all bytes will be read.
"""
if size == -1:
return self.readall()
# Empty blob or already read to the end
if size == 0 or self._offset >= self.size:
return b'' if not self._encoding else ''

stream = BytesIO()
remaining_size = size

# Start by reading from current_content if there is data left
if self._offset < len(self._current_content):
start = self._offset
end = min(remaining_size, len(self._current_content) - self._offset)
read = stream.write(self._current_content[start:end])

remaining_size -= read
self._offset += read

if remaining_size > 0:
start_range = self._get_downloader_start_with_offset()

# End is the min between the remaining size, the file size, and the end of the specified range
end_range = min(start_range + remaining_size, self._file_size)
if self._end_range is not None:
end_range = min(end_range, self._end_range + 1)

parallel = self._max_concurrency > 1
downloader = _ChunkDownloader(
client=self._clients.blob,
non_empty_ranges=self._non_empty_ranges,
total_size=self.size,
chunk_size=self._config.max_chunk_get_size,
current_progress=self._offset,
start_range=start_range,
end_range=end_range,
stream=stream,
parallel=parallel,
validate_content=self._validate_content,
encryption_options=self._encryption_options,
encryption_data=self._encryption_data,
use_location=self._location_mode,
**self._request_options
)

if parallel and remaining_size > self._config.max_chunk_get_size:
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(self._max_concurrency) as executor:
list(executor.map(
with_current_context(downloader.process_chunk),
downloader.get_chunk_offsets()
))
else:
for chunk in downloader.get_chunk_offsets():
downloader.process_chunk(chunk)

self._offset += remaining_size

data = stream.getvalue()
if self._encoding:
return data.decode(self._encoding)
return data

def readall(self):
# type: () -> T
"""Download the contents of this blob.
Expand Down Expand Up @@ -625,30 +704,36 @@ def readinto(self, stream):
except (NotImplementedError, AttributeError):
raise ValueError(error_message)

# Write the content to the user stream
stream.write(self._current_content)
if self._progress_hook:
self._progress_hook(len(self._current_content), self.size)
# If some data has been streamed using `read`, only stream the remaining data
remaining_size = self.size - self._offset
# Already read to the end
if remaining_size <= 0:
return 0

# Write the content to the user stream if there is data left
if self._offset < len(self._current_content):
content = self._current_content[self._offset:]
stream.write(content)
self._offset += len(content)
if self._progress_hook:
self._progress_hook(len(content), self.size)

if self._download_complete:
return self.size
return remaining_size

data_end = self._file_size
if self._end_range is not None:
# Use the length unless it is over the end of the file
data_end = min(self._file_size, self._end_range + 1)

data_start = self._initial_range[1] + 1 # Start where the first download ended
# For encryption, adjust start to the end of the fetched data rather than download size
if self._encryption_options.get("key") is not None or self._encryption_options.get("resolver") is not None:
data_start = (self._start_range or 0) + len(self._current_content)
data_start = self._get_downloader_start_with_offset()

downloader = _ChunkDownloader(
client=self._clients.blob,
non_empty_ranges=self._non_empty_ranges,
total_size=self.size,
chunk_size=self._config.max_chunk_get_size,
current_progress=self._first_get_size,
current_progress=self._offset,
start_range=data_start,
end_range=data_end,
stream=stream,
Expand All @@ -670,7 +755,8 @@ def readinto(self, stream):
else:
for chunk in downloader.get_chunk_offsets():
downloader.process_chunk(chunk)
return self.size

return remaining_size

def download_to_stream(self, stream, max_concurrency=1):
"""Download the contents of this blob to a stream.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,26 @@
from .uploads import SubStream, IterStreamer # pylint: disable=unused-import


async def _async_parallel_uploads(uploader, pending, running):
range_ids = []
while True:
# Wait for some download to finish before adding a new one
done, running = await asyncio.wait(running, return_when=asyncio.FIRST_COMPLETED)
range_ids.extend([chunk.result() for chunk in done])
try:
for _ in range(0, len(done)):
next_chunk = await pending.__anext__()
running.add(asyncio.ensure_future(uploader(next_chunk)))
except StopAsyncIteration:
break

# Wait for the remaining uploads to finish
if running:
done, _running = await asyncio.wait(running)
range_ids.extend([chunk.result() for chunk in done])
annatisch marked this conversation as resolved.
Show resolved Hide resolved
return range_ids


async def _parallel_uploads(uploader, pending, running):
range_ids = []
while True:
Expand Down Expand Up @@ -65,14 +85,18 @@ async def upload_data_chunks(

if parallel:
upload_tasks = uploader.get_chunk_streams()
running_futures = [
asyncio.ensure_future(uploader.process_chunk(u))
for u in islice(upload_tasks, 0, max_concurrency)
]
range_ids = await _parallel_uploads(uploader.process_chunk, upload_tasks, running_futures)
running_futures = []
for _ in range(max_concurrency):
try:
chunk = await upload_tasks.__anext__()
running_futures.append(asyncio.ensure_future(uploader.process_chunk(chunk)))
except StopAsyncIteration:
break

range_ids = await _async_parallel_uploads(uploader.process_chunk, upload_tasks, running_futures)
else:
range_ids = []
for chunk in uploader.get_chunk_streams():
async for chunk in uploader.get_chunk_streams():
range_ids.append(await uploader.process_chunk(chunk))

if any(range_ids):
Expand Down Expand Up @@ -152,7 +176,7 @@ def __init__(
self.last_modified = None
self.request_options = kwargs

def get_chunk_streams(self):
async def get_chunk_streams(self):
index = 0
while True:
data = b''
Expand All @@ -162,7 +186,10 @@ def get_chunk_streams(self):
while True:
if self.total_size:
read_size = min(self.chunk_size - len(data), self.total_size - (index + len(data)))
temp = self.stream.read(read_size)
if asyncio.iscoroutinefunction(self.stream.read):
temp = await self.stream.read(read_size)
else:
temp = self.stream.read(read_size)
if not isinstance(temp, six.binary_type):
raise TypeError('Blob data should be of type bytes.')
data += temp or b""
Expand Down
Loading