Skip to content

Commit

Permalink
[Storage] Remove client-side encryption code from shared (#24931)
Browse files Browse the repository at this point in the history
  • Loading branch information
jalauzon-msft authored Jun 25, 2022
1 parent 10304bf commit 88055b3
Show file tree
Hide file tree
Showing 53 changed files with 400 additions and 2,640 deletions.
52 changes: 25 additions & 27 deletions sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,31 @@
# license information.
# --------------------------------------------------------------------------
# pylint: disable=too-many-lines,no-self-use

from functools import partial
from io import BytesIO
from typing import ( # pylint: disable=unused-import
Union, Optional, Any, IO, Iterable, AnyStr, Dict, List, Tuple,
TYPE_CHECKING,
TypeVar, Type)
from typing import (
Any, AnyStr, Dict, IO, Iterable, List, Optional, Tuple, Type, TypeVar, Union,
TYPE_CHECKING
)
from urllib.parse import urlparse, quote, unquote
import warnings

try:
from urllib.parse import urlparse, quote, unquote
except ImportError:
from urlparse import urlparse # type: ignore
from urllib2 import quote, unquote # type: ignore
import six
from azure.core.exceptions import ResourceNotFoundError, HttpResponseError, ResourceExistsError
from azure.core.paging import ItemPaged
from azure.core.pipeline import Pipeline
from azure.core.tracing.decorator import distributed_trace
from azure.core.exceptions import ResourceNotFoundError, HttpResponseError, ResourceExistsError

from ._shared import encode_base64
from ._shared.base_client import StorageAccountHostsMixin, parse_connection_str, parse_query, TransportWrapper
from ._shared.encryption import generate_blob_encryption_data
from ._shared.uploads import IterStreamer
from ._shared.request_handlers import (
add_metadata_headers, get_length, read_length,
validate_and_format_range_headers)
from ._shared.response_handlers import return_response_headers, process_storage_error, return_headers_and_deserialized
from ._generated import AzureBlobStorage
from ._generated.models import ( # pylint: disable=unused-import
from ._generated.models import (
DeleteSnapshotsOptionType,
BlobHTTPHeaders,
BlockLookupList,
Expand All @@ -49,22 +45,30 @@
serialize_blob_tags,
serialize_query_format, get_access_conditions
)
from ._deserialize import get_page_ranges_result, deserialize_blob_properties, deserialize_blob_stream, parse_tags, \
from ._deserialize import (
get_page_ranges_result,
deserialize_blob_properties,
deserialize_blob_stream,
parse_tags,
deserialize_pipeline_response_into_cls
)
from ._download import StorageStreamDownloader
from ._encryption import StorageEncryptionMixin
from ._lease import BlobLeaseClient
from ._models import BlobType, BlobBlock, BlobProperties, BlobQueryError, QuickQueryDialect, \
DelimitedJsonDialect, DelimitedTextDialect, PageRangePaged, PageRange
from ._quick_query_helper import BlobQueryReader
from ._upload_helpers import (
upload_block_blob,
upload_append_blob,
upload_page_blob, _any_conditions)
from ._models import BlobType, BlobBlock, BlobProperties, BlobQueryError, QuickQueryDialect, \
DelimitedJsonDialect, DelimitedTextDialect, PageRangePaged, PageRange
from ._download import StorageStreamDownloader
from ._lease import BlobLeaseClient
upload_page_blob,
_any_conditions
)

if TYPE_CHECKING:
from datetime import datetime
from ._generated.models import BlockList
from ._models import ( # pylint: disable=unused-import
from ._models import (
ContentSettings,
ImmutabilityPolicy,
PremiumPageBlobTier,
Expand All @@ -79,7 +83,7 @@
ClassType = TypeVar("ClassType")


class BlobClient(StorageAccountHostsMixin): # pylint: disable=too-many-public-methods
class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): # pylint: disable=too-many-public-methods
"""A client to interact with a specific blob, although that blob may not yet exist.
For more optional configuration, please click
Expand Down Expand Up @@ -181,6 +185,7 @@ def __init__(
super(BlobClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs)
self._client = AzureBlobStorage(self.url, base_url=self.url, pipeline=self._pipeline)
self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access
self.configure_encryption(kwargs)

def _format_url(self, hostname):
container_name = self.container_name
Expand Down Expand Up @@ -359,13 +364,6 @@ def _upload_blob_options( # pylint:disable=too-many-statements
'key': self.key_encryption_key,
'resolver': self.key_resolver_function,
}
if self.key_encryption_key is not None:
cek, iv, encryption_data = generate_blob_encryption_data(
self.key_encryption_key,
self.encryption_version)
encryption_options['cek'] = cek
encryption_options['vector'] = iv
encryption_options['data'] = encryption_data

encoding = kwargs.pop('encoding', 'UTF-8')
if isinstance(data, six.text_type):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,33 @@
import functools
import warnings
from typing import ( # pylint: disable=unused-import
Union, Optional, Any, Iterable, Dict, List,
TYPE_CHECKING,
TypeVar)
Any, Dict, List, Optional, TypeVar, Union,
TYPE_CHECKING
)
from urllib.parse import urlparse


try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse # type: ignore

from azure.core.paging import ItemPaged
from azure.core.exceptions import HttpResponseError
from azure.core.paging import ItemPaged
from azure.core.pipeline import Pipeline
from azure.core.tracing.decorator import distributed_trace

from ._shared.models import LocationMode
from ._shared.base_client import StorageAccountHostsMixin, TransportWrapper, parse_connection_str, parse_query
from ._shared.models import LocationMode
from ._shared.parser import _to_utc_datetime
from ._shared.response_handlers import return_response_headers, process_storage_error, \
from ._shared.response_handlers import (
return_response_headers,
process_storage_error,
parse_to_internal_user_delegation_key
)
from ._generated import AzureBlobStorage
from ._generated.models import StorageServiceProperties, KeyInfo
from ._container_client import ContainerClient
from ._blob_client import BlobClient
from ._models import ContainerPropertiesPaged
from ._deserialize import service_stats_deserialize, service_properties_deserialize
from ._encryption import StorageEncryptionMixin
from ._list_blobs_helper import FilteredBlobPaged
from ._models import ContainerPropertiesPaged
from ._serialize import get_api_version
from ._deserialize import service_stats_deserialize, service_properties_deserialize

if TYPE_CHECKING:
from datetime import datetime
Expand All @@ -55,7 +54,7 @@
ClassType = TypeVar("ClassType")


class BlobServiceClient(StorageAccountHostsMixin):
class BlobServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin):
"""A client to interact with the Blob Service at the account level.
This client provides operations to retrieve and configure the account properties
Expand Down Expand Up @@ -137,6 +136,7 @@ def __init__(
super(BlobServiceClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs)
self._client = AzureBlobStorage(self.url, base_url=self.url, pipeline=self._pipeline)
self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access
self.configure_encryption(kwargs)

def _format_url(self, hostname):
"""Format the endpoint URL according to the current location
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,53 +7,47 @@

import functools
from typing import ( # pylint: disable=unused-import
Union, Optional, Any, Iterable, AnyStr, Dict, List, Tuple, IO, Iterator,
TYPE_CHECKING,
TypeVar)


try:
from urllib.parse import urlparse, quote, unquote
except ImportError:
from urlparse import urlparse # type: ignore
from urllib2 import quote, unquote # type: ignore
Any, AnyStr, Dict, List, IO, Iterable, Iterator, Optional, TypeVar, Union,
TYPE_CHECKING
)
from urllib.parse import urlparse, quote, unquote

import six

from azure.core import MatchConditions
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from azure.core.paging import ItemPaged
from azure.core.tracing.decorator import distributed_trace
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import HttpRequest
from azure.core.tracing.decorator import distributed_trace

from ._shared.base_client import StorageAccountHostsMixin, TransportWrapper, parse_connection_str, parse_query
from ._shared.request_handlers import add_metadata_headers, serialize_iso
from ._shared.response_handlers import (
process_storage_error,
return_response_headers,
return_headers_and_deserialized)
return_headers_and_deserialized
)
from ._generated import AzureBlobStorage
from ._generated.models import SignedIdentifier
from ._blob_client import BlobClient
from ._deserialize import deserialize_container_properties
from ._serialize import get_modify_conditions, get_container_cpk_scope_info, get_api_version, get_access_conditions
from ._models import ( # pylint: disable=unused-import
from ._encryption import StorageEncryptionMixin
from ._lease import BlobLeaseClient
from ._list_blobs_helper import BlobPrefix, BlobPropertiesPaged, FilteredBlobPaged
from ._models import (
ContainerProperties,
BlobProperties,
BlobType,
FilteredBlob)
from ._list_blobs_helper import BlobPrefix, BlobPropertiesPaged, FilteredBlobPaged
from ._lease import BlobLeaseClient
from ._blob_client import BlobClient
FilteredBlob
)
from ._serialize import get_modify_conditions, get_container_cpk_scope_info, get_api_version, get_access_conditions

if TYPE_CHECKING:
from azure.core.pipeline.transport import HttpTransport, HttpResponse # pylint: disable=ungrouped-imports
from azure.core.pipeline.policies import HTTPPolicy # pylint: disable=ungrouped-imports
from azure.core.pipeline.transport import HttpResponse # pylint: disable=ungrouped-imports
from datetime import datetime
from ._models import ( # pylint: disable=unused-import
PublicAccess,
AccessPolicy,
ContentSettings,
StandardBlobTier,
PremiumPageBlobTier)

Expand All @@ -73,7 +67,7 @@ def _get_blob_name(blob):
ClassType = TypeVar("ClassType")


class ContainerClient(StorageAccountHostsMixin): # pylint: disable=too-many-public-methods
class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): # pylint: disable=too-many-public-methods
"""A client to interact with a specific container, although that container
may not yet exist.
Expand Down Expand Up @@ -161,6 +155,7 @@ def __init__(
super(ContainerClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs)
self._client = AzureBlobStorage(self.url, base_url=self.url, pipeline=self._pipeline)
self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access
self.configure_encryption(kwargs)

def _format_url(self, hostname):
container_name = self.container_name
Expand Down
11 changes: 5 additions & 6 deletions sdk/storage/azure-storage-blob/azure/storage/blob/_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,24 @@
import sys
import threading
import time

import warnings
from io import BytesIO
from typing import Iterator, Union

import requests
from azure.core.exceptions import HttpResponseError, ServiceResponseError

from azure.core.tracing.common import with_current_context
from ._shared.encryption import (

from ._shared.request_handlers import validate_and_format_range_headers
from ._shared.response_handlers import process_storage_error, parse_length_from_content_range
from ._deserialize import deserialize_blob_properties, get_page_ranges_result
from ._encryption import (
adjust_blob_size_for_encryption,
decrypt_blob,
get_adjusted_download_range_and_offset,
is_encryption_v2,
parse_encryption_data
)
from ._shared.request_handlers import validate_and_format_range_headers
from ._shared.response_handlers import process_storage_error, parse_length_from_content_range
from ._deserialize import deserialize_blob_properties, get_page_ranges_result


def process_range_and_offset(start_range, end_range, length, encryption_options, encryption_data):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import math
import sys
import warnings
from collections import OrderedDict
from io import BytesIO
from json import (
Expand All @@ -24,8 +25,8 @@

from azure.core.exceptions import HttpResponseError

from .._version import VERSION
from . import encode_base64, decode_base64_to_bytes
from ._version import VERSION
from ._shared import encode_base64, decode_base64_to_bytes


_ENCRYPTION_PROTOCOL_V1 = '1.0'
Expand Down Expand Up @@ -53,6 +54,19 @@ def _validate_key_encryption_key_wrap(kek):
raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_key_wrap_algorithm'))


class StorageEncryptionMixin(object):
def configure_encryption(self, kwargs):
self.require_encryption = kwargs.get("require_encryption", False)
self.encryption_version = kwargs.get("encryption_version", "1.0")
self.key_encryption_key = kwargs.get("key_encryption_key")
self.key_resolver_function = kwargs.get("key_resolver_function")
if self.key_encryption_key and self.encryption_version == '1.0':
warnings.warn("This client has been configured to use encryption with version 1.0. " +
"Version 1.0 is deprecated and no longer considered secure. It is highly " +
"recommended that you switch to using version 2.0. The version can be " +
"specified using the 'encryption_version' keyword.")


class _EncryptionAlgorithm(object):
'''
Specifies which client encryption algorithm is used.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# --------------------------------------------------------------------------
import logging
import uuid
import warnings
from typing import ( # pylint: disable=unused-import
Optional,
Any,
Expand Down Expand Up @@ -105,16 +104,6 @@ def __init__(
primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip('/')
self._hosts = {LocationMode.PRIMARY: primary_hostname, LocationMode.SECONDARY: secondary_hostname}

self.require_encryption = kwargs.get("require_encryption", False)
self.encryption_version = kwargs.get("encryption_version", "1.0")
self.key_encryption_key = kwargs.get("key_encryption_key")
self.key_resolver_function = kwargs.get("key_resolver_function")
if self.key_encryption_key and self.encryption_version == '1.0':
warnings.warn("This client has been configured to use encryption with version 1.0. \
Version 1.0 is deprecated and no longer considered secure. It is highly \
recommended that you switch to using version 2.0. The version can be \
specified using the 'encryption_version' keyword.")

self._config, self._pipeline = self._create_pipeline(self.credential, storage_sdk=service, **kwargs)

def __enter__(self):
Expand Down
Loading

0 comments on commit 88055b3

Please sign in to comment.