Azure blob storage (#2284)
* Add support for the Azure Blob Storage provider

* Add new providers to docs

* Add changelog

* Fix tests
romasku authored Sep 2, 2021
1 parent fbb37ee commit 702098b
Showing 6 changed files with 192 additions and 13 deletions.
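
For orientation, here is a minimal sketch of driving the new provider directly; it is not part of the commit. It assumes an existing Bucket on an Azure cluster and a coroutine that returns temporary credentials (the role request_tmp_credentials() plays inside the SDK); get_credentials and main() are illustrative stand-ins, and fetch_blob() is entered with async with because it is wrapped with asyncgeneratorcontextmanager.

from typing import Awaitable, Callable

from neuro_sdk.buckets import AzureProvider, Bucket, BucketCredentials


async def main(
    bucket: Bucket,
    get_credentials: Callable[[], Awaitable[BucketCredentials]],
) -> None:
    # AzureProvider.create() is an async context manager: it builds a
    # ContainerClient from a SAS token and renews that token in the background.
    async with AzureProvider.create(bucket, get_credentials) as provider:
        await provider.put_blob("hello.txt", b"hello")  # single-shot upload
        entry = await provider.head_blob("hello.txt")   # metadata lookup
        print(entry.size, entry.modified_at)
        async with provider.fetch_blob("hello.txt") as chunks:
            async for chunk in chunks:                  # streamed download
                print(len(chunk))
        await provider.delete_blob("hello.txt")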
1 change: 1 addition & 0 deletions CHANGELOG.D/2284.feature
@@ -0,0 +1 @@
Added support of blob storage for azure clusters.
8 changes: 8 additions & 0 deletions neuro-sdk/docs/buckets_reference.rst
@@ -436,6 +436,14 @@ Bucket.Provider

   Amazon Web Services S3 bucket

.. attribute:: MINIO

   Minio S3 bucket

.. attribute:: AZURE

   Azure blob storage container
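
A quick sketch of dispatching on these attributes (describe() is a made-up helper; the members are the ones documented above):

from neuro_sdk.buckets import Bucket

def describe(bucket: Bucket) -> str:
    # Bucket.Provider is the enum documented above; AZURE is the new member.
    if bucket.provider == Bucket.Provider.AZURE:
        return "Azure blob storage container"
    if bucket.provider in (Bucket.Provider.AWS, Bucket.Provider.MINIO):
        return "S3-compatible bucket"
    raise ValueError(f"unknown provider {bucket.provider}")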


PersistentBucketCredentials
===========================
1 change: 1 addition & 0 deletions neuro-sdk/docs/spelling_wordlist.txt
@@ -8,6 +8,7 @@ globbing
grpc
iterable
login
minio
namespace
neu
neuro
1 change: 1 addition & 0 deletions neuro-sdk/setup.py
@@ -25,6 +25,7 @@
"certifi",
"toml>=0.10.0",
"aiobotocore>=1.3.3",
"azure-storage-blob>=12.8.1",
],
include_package_data=True,
description="Neu.ro SDK",
191 changes: 178 additions & 13 deletions neuro-sdk/src/neuro_sdk/buckets.py
@@ -1,12 +1,14 @@
import abc
import asyncio
import enum
import json
import logging
import re
import secrets
import sys
import time
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timedelta, timezone
from email.utils import parsedate_to_datetime
from pathlib import PurePosixPath
from typing import (
@@ -27,6 +29,11 @@
import botocore.exceptions
from aiobotocore.client import AioBaseClient
from aiobotocore.credentials import AioRefreshableCredentials
from azure.core.credentials import AzureSasCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.storage.blob import BlobBlock
from azure.storage.blob.aio import ContainerClient
from azure.storage.blob.aio._list_blobs_helper import BlobPrefix
from dateutil.parser import isoparse
from yarl import URL

@@ -264,19 +271,15 @@ def child(self, path: PurePosixPath, child: str) -> PurePosixPath:
        return path / child


-class S3Provider(BucketProvider):
-    def __init__(
-        self, client: AioBaseClient, bucket: "Bucket", bucket_name: str
-    ) -> None:
-        self.bucket = bucket
-        client._make_api_call = self._wrap_api_call(client._make_api_call)
-        self._client = client
+class MeasureTimeDiffMixin:
+    def __init__(self) -> None:
        self._min_time_diff: Optional[float] = 0
        self._max_time_diff: Optional[float] = 0
-        self._bucket_name = bucket_name

    def _wrap_api_call(
-        self, _make_call: Callable[..., Awaitable[Any]]
+        self,
+        _make_call: Callable[..., Awaitable[Any]],
+        get_date: Callable[[Any], datetime],
    ) -> Callable[..., Awaitable[Any]]:
        def _average(cur_approx: Optional[float], new_val: float) -> float:
            if cur_approx is None:
@@ -288,9 +291,8 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any:
            res = await _make_call(*args, **kwargs)
            after = time.time()
            try:
-                date_str = res["ResponseMetadata"]["HTTPHeaders"]["date"]
-                server_dt = parsedate_to_datetime(date_str)
-            except (KeyError, TypeError, ValueError):
+                server_dt = get_date(res)
+            except Exception:
                pass
            else:
                server_time = server_dt.timestamp()
@@ -308,6 +310,29 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any:

        return _wrapper

    async def get_time_diff_to_local(self) -> Tuple[float, float]:
        if self._min_time_diff is None or self._max_time_diff is None:
            return 0, 0
        return self._min_time_diff, self._max_time_diff


class S3Provider(MeasureTimeDiffMixin, BucketProvider):
    def __init__(
        self, client: AioBaseClient, bucket: "Bucket", bucket_name: str
    ) -> None:
        super().__init__()
        self.bucket = bucket
        self._client = client
        self._bucket_name = bucket_name

        def _extract_date(resp: Any) -> datetime:
            date_str = resp["ResponseMetadata"]["HTTPHeaders"]["date"]
            return parsedate_to_datetime(date_str)

        client._make_api_call = self._wrap_api_call(
            client._make_api_call, _extract_date
        )

    @classmethod
    @asynccontextmanager
    async def create(
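
MeasureTimeDiffMixin brackets the client-server clock difference by timestamping each wrapped call on the way out and back: if the server stamped its response between our before and after readings, the local-minus-server skew must lie in [before - server_time, after - server_time]. A self-contained sketch of the same idea, with two simplifications: it keeps raw min/max bounds where the mixin smooths them through _average(), and it ignores that HTTP Date headers are truncated to whole seconds. ClockSkewTracker, fake_api_call, and demo are illustrative names, not SDK API.

import asyncio
import time
from datetime import datetime, timezone
from email.utils import format_datetime, parsedate_to_datetime
from typing import Any, Awaitable, Callable, Optional, Tuple


class ClockSkewTracker:
    def __init__(self) -> None:
        self._min_diff: Optional[float] = None
        self._max_diff: Optional[float] = None

    def wrap(
        self,
        call: Callable[..., Awaitable[Any]],
        get_date: Callable[[Any], datetime],
    ) -> Callable[..., Awaitable[Any]]:
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            before = time.time()
            res = await call(*args, **kwargs)
            after = time.time()
            try:
                server_time = get_date(res).timestamp()
            except Exception:
                pass  # response carried no usable timestamp
            else:
                lo, hi = before - server_time, after - server_time
                # Keep the widest bracket seen so far (a conservative bound).
                self._min_diff = lo if self._min_diff is None else min(self._min_diff, lo)
                self._max_diff = hi if self._max_diff is None else max(self._max_diff, hi)
            return res

        return wrapper

    def bounds(self) -> Tuple[float, float]:
        if self._min_diff is None or self._max_diff is None:
            return 0.0, 0.0
        return self._min_diff, self._max_diff


async def fake_api_call() -> dict:
    # Stand-in for a real HTTP call: any response exposing a Date header works.
    return {"Date": format_datetime(datetime.now(timezone.utc), usegmt=True)}


async def demo() -> None:
    tracker = ClockSkewTracker()
    call = tracker.wrap(fake_api_call, lambda r: parsedate_to_datetime(r["Date"]))
    await call()
    print(tracker.bounds())


asyncio.run(demo())

The SDK exposes its bounds through get_time_diff_to_local(), presumably so callers can compare server-side timestamps against the local clock without trusting either one.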
@@ -461,6 +486,141 @@ async def fetch_blob(self, key: str, offset: int = 0) -> AsyncIterator[bytes]:
    async def delete_blob(self, key: str) -> None:
        await self._client.delete_object(Bucket=self._bucket_name, Key=key)


class AzureProvider(MeasureTimeDiffMixin, BucketProvider):
    def __init__(self, container_client: ContainerClient, bucket: "Bucket") -> None:
        super().__init__()
        self.bucket = bucket

        self._client = container_client

        def _extract_date(resp: Any) -> datetime:
            date_str = resp.http_response.headers["Date"]
            return parsedate_to_datetime(date_str)

        # Hack to get client-server clock difference
        container_client._client._client._pipeline.run = self._wrap_api_call(
            container_client._client._client._pipeline.run, _extract_date
        )

    @classmethod
    @asynccontextmanager
    async def create(
        cls,
        bucket: "Bucket",
        _get_credentials: Callable[[], Awaitable["BucketCredentials"]],
    ) -> AsyncIterator["AzureProvider"]:
        initial_credentials = await _get_credentials()

        sas_credential = AzureSasCredential(
            initial_credentials.credentials["sas_token"]
        )

        @asynccontextmanager
        async def _token_renewer() -> AsyncIterator[None]:
            async def renew_token_loop() -> None:
                expiry = isoparse(initial_credentials.credentials["expiry"])
                while True:
                    delay = (
                        expiry - timedelta(minutes=10) - datetime.now(timezone.utc)
                    ).total_seconds()
                    await asyncio.sleep(max(delay, 0))
                    credentials = await _get_credentials()
                    sas_credential.update(credentials.credentials["sas_token"])
                    expiry = isoparse(credentials.credentials["expiry"])

            task = asyncio.ensure_future(renew_token_loop())
            try:
                yield
            finally:
                task.cancel()

        async with ContainerClient(
            account_url=initial_credentials.credentials["storage_endpoint"],
            container_name=initial_credentials.credentials["bucket_name"],
            credential=sas_credential,
        ) as container_client, _token_renewer():
            yield cls(container_client, bucket)

    @asyncgeneratorcontextmanager
    async def list_blobs(
        self, prefix: str, recursive: bool = False, limit: Optional[int] = None
    ) -> AsyncIterator[BucketEntry]:
        if recursive:
            it = self._client.list_blobs(prefix)
        else:
            it = self._client.walk_blobs(prefix)
        count = 0
        async for item in it:
            if isinstance(item, BlobPrefix):
                entry: BucketEntry = BlobCommonPrefix(
                    bucket=self.bucket,
                    key=item.name,
                    size=0,
                )
            else:
                entry = BlobObject(
                    bucket=self.bucket,
                    key=item.name,
                    size=item.size,
                    created_at=item.creation_time,
                    modified_at=item.last_modified,
                )
            yield entry
            count += 1
            if count == limit:
                return

    async def head_blob(self, key: str) -> BucketEntry:
        try:
            blob_info = await self._client.get_blob_client(key).get_blob_properties()
            return BlobObject(
                bucket=self.bucket,
                key=blob_info.name,
                size=blob_info.size,
                created_at=blob_info.creation_time,
                modified_at=blob_info.last_modified,
            )
        except ResourceNotFoundError:
            raise ResourceNotFound(
                f"There is no object with key {key} in bucket {self.bucket.name}"
            )

    async def put_blob(
        self, key: str, body: Union[AsyncIterator[bytes], bytes]
    ) -> None:
        blob_client = self._client.get_blob_client(key)
        if isinstance(body, bytes):
            await blob_client.upload_blob(body)
        else:
            blocks = []
            async for data in body:
                block_id = secrets.token_hex(16)
                await blob_client.stage_block(block_id, data)
                blocks.append(BlobBlock(block_id=block_id))
            await blob_client.commit_block_list(blocks)

    @asyncgeneratorcontextmanager
    async def fetch_blob(self, key: str, offset: int = 0) -> AsyncIterator[bytes]:
        try:
            downloader = await self._client.get_blob_client(key).download_blob(
                offset=offset
            )
        except ResourceNotFoundError:
            raise ResourceNotFound(
                f"There is no object with key {key} in bucket {self.bucket.name}"
            )
        async for chunk in downloader.chunks():
            yield chunk

    async def delete_blob(self, key: str) -> None:
        try:
            await self._client.get_blob_client(key).delete_blob()
        except ResourceNotFoundError:
            raise ResourceNotFound(
                f"There is no object with key {key} in bucket {self.bucket.name}"
            )

    async def get_time_diff_to_local(self) -> Tuple[float, float]:
        if self._min_time_diff is None or self._max_time_diff is None:
            return 0, 0
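
The _token_renewer() above works because Azure clients keep a reference to the AzureSasCredential they were constructed with, so update() swaps in a fresh token without rebuilding the client. A standalone sketch of the pattern under one assumption: fetch_sas() is a hypothetical coroutine returning a fresh (token, expiry) pair.

import asyncio
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import AsyncIterator, Awaitable, Callable, Tuple

from azure.core.credentials import AzureSasCredential


@asynccontextmanager
async def renewing_sas(
    fetch_sas: Callable[[], Awaitable[Tuple[str, datetime]]],
) -> AsyncIterator[AzureSasCredential]:
    token, expiry = await fetch_sas()
    credential = AzureSasCredential(token)

    async def renew_loop() -> None:
        nonlocal expiry
        while True:
            # Wake up ten minutes before expiry, matching the loop above.
            delay = (
                expiry - timedelta(minutes=10) - datetime.now(timezone.utc)
            ).total_seconds()
            await asyncio.sleep(max(delay, 0))
            new_token, expiry = await fetch_sas()
            credential.update(new_token)  # picked up by clients holding this object

    task = asyncio.ensure_future(renew_loop())
    try:
        yield credential
    finally:
        task.cancel()

A ContainerClient constructed with this credential keeps working across renewals, which is exactly what AzureProvider.create() relies on.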
Expand All @@ -483,6 +643,7 @@ def uri(self) -> URL:
    class Provider(str, enum.Enum):
        AWS = "aws"
        MINIO = "minio"
        AZURE = "azure"


@dataclass(frozen=True)
@@ -603,9 +764,13 @@ async def _get_provider(
        async def _get_new_credentials() -> BucketCredentials:
            return await self.request_tmp_credentials(bucket_id_or_name, cluster_name)

        provider: BucketProvider
        if bucket.provider in (Bucket.Provider.AWS, Bucket.Provider.MINIO):
            async with S3Provider.create(bucket, _get_new_credentials) as provider:
                yield provider
        elif bucket.provider == Bucket.Provider.AZURE:
            async with AzureProvider.create(bucket, _get_new_credentials) as provider:
                yield provider
        else:
            assert False, f"Unknown provider {bucket.provider}"

3 changes: 3 additions & 0 deletions setup.cfg
@@ -110,3 +110,6 @@ ignore_missing_imports = true

[mypy-botocore.*]
ignore_missing_imports = true

[mypy-azure.*]
ignore_missing_imports = true
