Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cosmos] AAD authentication async client #23717

Merged
merged 19 commits into from
Apr 6, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,7 @@
### 4.3.0b4 (Unreleased)

#### Features Added

#### Breaking Changes

#### Bugs Fixed

#### Other Changes
- Added support for AAD authentication for the async client

### 4.3.0b3 (2022-03-10)

Expand Down
10 changes: 1 addition & 9 deletions sdk/cosmos/azure-cosmos/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ Currently the features below are **not supported**. For alternatives options, ch
* Change Feed: Processor
* Change Feed: Read multiple partitions key values
* Change Feed: Read specific time
* Change Feed: Read from the beggining
* Change Feed: Read from the beginning
* Change Feed: Pull model
* Cross-partition ORDER BY for mixed types

Expand All @@ -139,10 +139,6 @@ Currently the features below are **not supported**. For alternatives options, ch
* Get the connection string
* Get the minimum RU/s of a container

### Security Limitations:

* AAD support

## Workarounds

### Bulk processing Limitation Workaround
Expand All @@ -153,10 +149,6 @@ If you want to use Python SDK to perform bulk inserts to Cosmos DB, the best alt

Typically, you can use [Azure Portal](https://portal.azure.com/), [Azure Cosmos DB Resource Provider REST API](https://docs.microsoft.com/rest/api/cosmos-db-resource-provider), [Azure CLI](https://docs.microsoft.com/cli/azure/azure-cli-reference-for-cosmos-db) or [PowerShell](https://docs.microsoft.com/azure/cosmos-db/manage-with-powershell) for the control plane unsupported limitations.

### AAD Support Workaround

A possible workaround is to use managed identities to [programmatically](https://docs.microsoft.com/azure/cosmos-db/managed-identity-based-authentication) get the keys.

## Boolean Data Type

While the Python language [uses](https://docs.python.org/3/library/stdtypes.html?highlight=boolean#truth-value-testing) "True" and "False" for boolean types, Cosmos DB [accepts](https://docs.microsoft.com/azure/cosmos-db/sql-query-is-bool) "true" and "false" only. In other words, the Python language uses Boolean values with the first uppercase letter and all other lowercase letters, while Cosmos DB and its SQL language use only lowercase letters for those same Boolean values. How to deal with this challenge?
Expand Down
6 changes: 6 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from typing import Dict, Any

from urllib.parse import quote as urllib_quote
from urllib.parse import urlsplit

from azure.core import MatchConditions

Expand Down Expand Up @@ -663,6 +664,11 @@ def ParsePaths(paths):
return tokens


def create_scope_from_url(url):
parsed_url = urlsplit(url)
return parsed_url.scheme + "://" + parsed_url.hostname + "/.default"


def validate_cache_staleness_value(max_integrated_cache_staleness):
int(max_integrated_cache_staleness) # Will throw error if data type cant be converted to int
if max_integrated_cache_staleness <= 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@
from .partition_key import _Undefined, _Empty

ClassType = TypeVar("ClassType")


# pylint: disable=protected-access


Expand Down
175 changes: 175 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import asyncio
import time

from typing import Any, Awaitable, Optional, Dict, Union
from azure.core.pipeline.policies import AsyncHTTPPolicy
from azure.core.credentials import AccessToken
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.cosmos import http_constants


async def await_result(func, *args, **kwargs):
"""If func returns an awaitable, await it."""
result = func(*args, **kwargs)
if hasattr(result, "__await__"):
# type ignore on await: https://github.com/python/mypy/issues/7587
return await result # type: ignore
return result


class _AsyncCosmosBearerTokenCredentialPolicyBase(object):
"""Base class for a Bearer Token Credential Policy.
:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (TokenCredential, *str, **Any) -> None
super(_AsyncCosmosBearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._token = None # type: Optional[AccessToken]
self._lock = asyncio.Lock()

@staticmethod
def _enforce_https(request):
# type: (PipelineRequest) -> None

# move 'enforce_https' from options to context so it persists
# across retries but isn't passed to a transport implementation
option = request.context.options.pop("enforce_https", None)

# True is the default setting; we needn't preserve an explicit opt in to the default behavior
if option is False:
request.context["enforce_https"] = option

enforce_https = request.context.get("enforce_https", True)
if enforce_https and not request.http_request.url.lower().startswith("https"):
raise ValueError(
"Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
)

@staticmethod
def _update_headers(headers, token):
# type: (Dict[str, str], str) -> None
"""Updates the Authorization header with the cosmos signature and bearer token.
This is the main method that differentiates this policy from core's BearerTokenCredentialPolicy and works
to properly sign the authorization header for Cosmos' REST API. For more information:
https://docs.microsoft.com/rest/api/cosmos-db/access-control-on-cosmosdb-resources#authorization-header
:param dict headers: The HTTP Request headers
:param str token: The OAuth token.
"""
headers[http_constants.HttpHeaders.Authorization] = "type=aad&ver=1.0&sig={}".format(token)

@property
def _need_new_token(self) -> bool:
return not self._token or self._token.expires_on - time.time() < 300


class AsyncCosmosBearerTokenCredentialPolicy(_AsyncCosmosBearerTokenCredentialPolicyBase, AsyncHTTPPolicy):
"""Adds a bearer token Authorization header to requests.
:param credential: The credential.
:type credential: ~azure.core.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:raises ValueError: If https_enforce does not match with endpoint being used.
"""

async def on_request(self, request: "PipelineRequest") -> None: # pylint:disable=invalid-overridden-method
"""Adds a bearer token Authorization header to request and sends request to next policy.
:param request: The pipeline request object to be modified.
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
self._enforce_https(request) # pylint:disable=protected-access

if self._token is None or self._need_new_token:
async with self._lock:
# double check because another coroutine may have acquired a token while we waited to acquire the lock
if self._token is None or self._need_new_token:
self._token = await self._credential.get_token(*self._scopes)
self._update_headers(request.http_request.headers, self._token.token)

async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: "Any") -> None:
"""Acquire a token from the credential and authorize the request with it.
Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
authorize future requests.
:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
async with self._lock:
self._token = await self._credential.get_token(*scopes, **kwargs)
self._update_headers(request.http_request.headers, self._token.token)

async def send(self, request: "PipelineRequest") -> "PipelineResponse":
"""Authorize request with a bearer token and send it to the next policy
:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
"""
await await_result(self.on_request, request)
try:
response = await self.next.send(request)
await await_result(self.on_response, request, response)
except Exception: # pylint:disable=broad-except
handled = await await_result(self.on_exception, request)
if not handled:
raise
else:
if response.http_response.status_code == 401:
self._token = None # any cached token is invalid
if "WWW-Authenticate" in response.http_response.headers:
request_authorized = await self.on_challenge(request, response)
if request_authorized:
try:
response = await self.next.send(request)
await await_result(self.on_response, request, response)
except Exception: # pylint:disable=broad-except
handled = await await_result(self.on_exception, request)
if not handled:
raise

return response

async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
"""Authorize request according to an authentication challenge
This method is called when the resource provider responds 401 with a WWW-Authenticate header.
:param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
:param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
:returns: a bool indicating whether the policy should send the request
"""
# pylint:disable=unused-argument,no-self-use
return False

def on_response(self, request: PipelineRequest, response: PipelineResponse) -> Union[None, Awaitable[None]]:
"""Executed after the request comes back from the next policy.
:param request: Request to be modified after returning from the policy.
:type request: ~azure.core.pipeline.PipelineRequest
:param response: Pipeline response object
:type response: ~azure.core.pipeline.PipelineResponse
"""

def on_exception(self, request: PipelineRequest) -> None:
"""Executed when an exception is raised while executing the next policy.
This method is executed inside the exception handler.
:param request: The Pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
"""
# pylint: disable=no-self-use,unused-argument
return
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@
from .. import _session
from .. import _utils
from ..partition_key import _Undefined, _Empty
from ._auth_policy_async import AsyncCosmosBearerTokenCredentialPolicy

ClassType = TypeVar("ClassType")
# pylint: disable=protected-access


class CosmosClientConnection(object): # pylint: disable=too-many-public-methods,too-many-instance-attributes
"""Represents a document client.
Expand Down Expand Up @@ -113,9 +115,11 @@ def __init__(

self.master_key = None
self.resource_tokens = None
self.aad_credentials = None
if auth is not None:
self.master_key = auth.get("masterKey")
self.resource_tokens = auth.get("resourceTokens")
self.aad_credentials = auth.get("clientSecretCredential")

if auth.get("permissionFeed"):
self.resource_tokens = {}
Expand Down Expand Up @@ -176,12 +180,18 @@ def __init__(

self._user_agent = _utils.get_user_agent_async()

credentials_policy = None
if self.aad_credentials:
scopes = base.create_scope_from_url(self.url_connection)
credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(self.aad_credentials, scopes)

policies = [
HeadersPolicy(**kwargs),
ProxyPolicy(proxies=proxies),
UserAgentPolicy(base_user_agent=self._user_agent, **kwargs),
ContentDecodePolicy(),
retry_policy,
credentials_policy,
CustomHookPolicy(**kwargs),
NetworkTraceLoggingPolicy(**kwargs),
DistributedTracingPolicy(**kwargs),
Expand Down
13 changes: 6 additions & 7 deletions sdk/cosmos/azure-cosmos/azure/cosmos/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@
from hashlib import sha256
import hmac
import urllib.parse

from . import http_constants


def GetAuthorizationHeader(
cosmos_client_connection, verb, path, resource_id_or_fullname, is_name_based, resource_type, headers
cosmos_client_connection, verb, path, resource_id_or_fullname, is_name_based, resource_type, headers
):
"""Gets the authorization header.

Expand All @@ -51,18 +50,18 @@ def GetAuthorizationHeader(
resource_id_or_fullname = resource_id_or_fullname.lower()

if cosmos_client_connection.master_key:
return __GetAuthorizationTokenUsingMasterKey(
return __get_authorization_token_using_master_key(
verb, resource_id_or_fullname, resource_type, headers, cosmos_client_connection.master_key
)
if cosmos_client_connection.resource_tokens:
return __GetAuthorizationTokenUsingResourceTokens(
return __get_authorization_token_using_resource_token(
cosmos_client_connection.resource_tokens, path, resource_id_or_fullname
)

return None


def __GetAuthorizationTokenUsingMasterKey(verb, resource_id_or_fullname, resource_type, headers, master_key):
def __get_authorization_token_using_master_key(verb, resource_id_or_fullname, resource_type, headers, master_key):
"""Gets the authorization token using `master_key.

:param str verb:
Expand Down Expand Up @@ -97,7 +96,7 @@ def __GetAuthorizationTokenUsingMasterKey(verb, resource_id_or_fullname, resourc
return "type={type}&ver={ver}&sig={sig}".format(type=master_token, ver=token_version, sig=signature[:-1])


def __GetAuthorizationTokenUsingResourceTokens(resource_tokens, path, resource_id_or_fullname):
def __get_authorization_token_using_resource_token(resource_tokens, path, resource_id_or_fullname):
"""Get the authorization token using `resource_tokens`.

:param dict resource_tokens:
Expand Down Expand Up @@ -138,7 +137,7 @@ def __GetAuthorizationTokenUsingResourceTokens(resource_tokens, path, resource_i

# Get the last resource id or resource name from the path and get it's token from resource_tokens
for i in range(len(path_parts), 1, -1):
segment = path_parts[i-1]
segment = path_parts[i - 1]
sub_path = "/".join(path_parts[:i])
if not segment in resource_types and sub_path in resource_tokens:
return resource_tokens[sub_path]
Expand Down
6 changes: 4 additions & 2 deletions sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ def _build_auth(credential):
auth['resourceTokens'] = credential # type: ignore
elif hasattr(credential, '__iter__'):
auth['permissionFeed'] = credential
elif hasattr(credential, 'get_token'):
auth['clientSecretCredential'] = credential
else:
raise TypeError(
"Unrecognized credential type. Please supply the master key as str, "
"or a dictionary or resource tokens, or a list of permissions.")
"Unrecognized credential type. Please supply the master key as a string "
"or a dictionary, or resource tokens, or a list of permissions, or a ClientSecretCredential.")
return auth


Expand Down
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ class SubStatusCodes(object):
REDUNDANT_COLLECTION_PUT = 1009
SHARED_THROUGHPUT_DATABASE_QUOTA_EXCEEDED = 1010
SHARED_THROUGHPUT_OFFER_GROW_NOT_NEEDED = 1011
AAD_REQUEST_NOT_AUTHORIZED = 5300

# 404: LSN in session token is higher
READ_SESSION_NOTAVAILABLE = 1002
Expand Down
Loading