Skip to content

Commit

Permalink
[Cosmos] AAD authentication async client (#23717)
Browse files Browse the repository at this point in the history
* working authentication to get database account

* working aad authentication for sync client with sample

* readme and changelog

* pylint and better comments on sample

* working async aad

* Delete access_cosmos_with_aad.py

snuck its way into the async PR

* Update _auth_policies.py

* small changes

* Update _cosmos_client_connection.py

* removing changes made in sync

* Update _auth_policy_async.py

* Update _auth_policy_async.py

* Update _auth_policy_async.py

* added licenses to samples
  • Loading branch information
simorenoh authored Apr 6, 2022
1 parent 93a8dda commit c17568f
Show file tree
Hide file tree
Showing 23 changed files with 392 additions and 22 deletions.
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### 4.3.0b4 (Unreleased)

#### Features Added
- Added support for AAD authentication for the async client
- Added support for AAD authentication for the sync client

### 4.3.0b3 (2022-03-10)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@
from ._auth_policy import CosmosBearerTokenCredentialPolicy

ClassType = TypeVar("ClassType")


# pylint: disable=protected-access


Expand Down
175 changes: 175 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import asyncio
import time

from typing import Any, Awaitable, Optional, Dict, Union
from azure.core.pipeline.policies import AsyncHTTPPolicy
from azure.core.credentials import AccessToken
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.cosmos import http_constants


async def await_result(func, *args, **kwargs):
"""If func returns an awaitable, await it."""
result = func(*args, **kwargs)
if hasattr(result, "__await__"):
# type ignore on await: https://github.com/python/mypy/issues/7587
return await result # type: ignore
return result


class _AsyncCosmosBearerTokenCredentialPolicyBase(object):
"""Base class for a Bearer Token Credential Policy.
:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (TokenCredential, *str, **Any) -> None
super(_AsyncCosmosBearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._token = None # type: Optional[AccessToken]
self._lock = asyncio.Lock()

@staticmethod
def _enforce_https(request):
# type: (PipelineRequest) -> None

# move 'enforce_https' from options to context so it persists
# across retries but isn't passed to a transport implementation
option = request.context.options.pop("enforce_https", None)

# True is the default setting; we needn't preserve an explicit opt in to the default behavior
if option is False:
request.context["enforce_https"] = option

enforce_https = request.context.get("enforce_https", True)
if enforce_https and not request.http_request.url.lower().startswith("https"):
raise ValueError(
"Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
)

@staticmethod
def _update_headers(headers, token):
# type: (Dict[str, str], str) -> None
"""Updates the Authorization header with the cosmos signature and bearer token.
This is the main method that differentiates this policy from core's BearerTokenCredentialPolicy and works
to properly sign the authorization header for Cosmos' REST API. For more information:
https://docs.microsoft.com/rest/api/cosmos-db/access-control-on-cosmosdb-resources#authorization-header
:param dict headers: The HTTP Request headers
:param str token: The OAuth token.
"""
headers[http_constants.HttpHeaders.Authorization] = "type=aad&ver=1.0&sig={}".format(token)

@property
def _need_new_token(self) -> bool:
return not self._token or self._token.expires_on - time.time() < 300


class AsyncCosmosBearerTokenCredentialPolicy(_AsyncCosmosBearerTokenCredentialPolicyBase, AsyncHTTPPolicy):
"""Adds a bearer token Authorization header to requests.
:param credential: The credential.
:type credential: ~azure.core.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:raises ValueError: If https_enforce does not match with endpoint being used.
"""

async def on_request(self, request: "PipelineRequest") -> None: # pylint:disable=invalid-overridden-method
"""Adds a bearer token Authorization header to request and sends request to next policy.
:param request: The pipeline request object to be modified.
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
self._enforce_https(request) # pylint:disable=protected-access

if self._token is None or self._need_new_token:
async with self._lock:
# double check because another coroutine may have acquired a token while we waited to acquire the lock
if self._token is None or self._need_new_token:
self._token = await self._credential.get_token(*self._scopes)
self._update_headers(request.http_request.headers, self._token.token)

async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: "Any") -> None:
"""Acquire a token from the credential and authorize the request with it.
Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
authorize future requests.
:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
async with self._lock:
self._token = await self._credential.get_token(*scopes, **kwargs)
self._update_headers(request.http_request.headers, self._token.token)

async def send(self, request: "PipelineRequest") -> "PipelineResponse":
"""Authorize request with a bearer token and send it to the next policy
:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
"""
await await_result(self.on_request, request)
try:
response = await self.next.send(request)
await await_result(self.on_response, request, response)
except Exception: # pylint:disable=broad-except
handled = await await_result(self.on_exception, request)
if not handled:
raise
else:
if response.http_response.status_code == 401:
self._token = None # any cached token is invalid
if "WWW-Authenticate" in response.http_response.headers:
request_authorized = await self.on_challenge(request, response)
if request_authorized:
try:
response = await self.next.send(request)
await await_result(self.on_response, request, response)
except Exception: # pylint:disable=broad-except
handled = await await_result(self.on_exception, request)
if not handled:
raise

return response

async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
"""Authorize request according to an authentication challenge
This method is called when the resource provider responds 401 with a WWW-Authenticate header.
:param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
:param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
:returns: a bool indicating whether the policy should send the request
"""
# pylint:disable=unused-argument,no-self-use
return False

def on_response(self, request: PipelineRequest, response: PipelineResponse) -> Union[None, Awaitable[None]]:
"""Executed after the request comes back from the next policy.
:param request: Request to be modified after returning from the policy.
:type request: ~azure.core.pipeline.PipelineRequest
:param response: Pipeline response object
:type response: ~azure.core.pipeline.PipelineResponse
"""

def on_exception(self, request: PipelineRequest) -> None:
"""Executed when an exception is raised while executing the next policy.
This method is executed inside the exception handler.
:param request: The Pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
"""
# pylint: disable=no-self-use,unused-argument
return
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@
from .. import _session
from .. import _utils
from ..partition_key import _Undefined, _Empty
from ._auth_policy_async import AsyncCosmosBearerTokenCredentialPolicy

ClassType = TypeVar("ClassType")
# pylint: disable=protected-access


class CosmosClientConnection(object): # pylint: disable=too-many-public-methods,too-many-instance-attributes
"""Represents a document client.
Expand Down Expand Up @@ -113,9 +115,11 @@ def __init__(

self.master_key = None
self.resource_tokens = None
self.aad_credentials = None
if auth is not None:
self.master_key = auth.get("masterKey")
self.resource_tokens = auth.get("resourceTokens")
self.aad_credentials = auth.get("clientSecretCredential")

if auth.get("permissionFeed"):
self.resource_tokens = {}
Expand Down Expand Up @@ -176,12 +180,18 @@ def __init__(

self._user_agent = _utils.get_user_agent_async()

credentials_policy = None
if self.aad_credentials:
scopes = base.create_scope_from_url(self.url_connection)
credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(self.aad_credentials, scopes)

policies = [
HeadersPolicy(**kwargs),
ProxyPolicy(proxies=proxies),
UserAgentPolicy(base_user_agent=self._user_agent, **kwargs),
ContentDecodePolicy(),
retry_policy,
credentials_policy,
CustomHookPolicy(**kwargs),
NetworkTraceLoggingPolicy(**kwargs),
DistributedTracingPolicy(**kwargs),
Expand Down
5 changes: 5 additions & 0 deletions sdk/cosmos/azure-cosmos/samples/access_cosmos_with_aad.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from azure.cosmos import CosmosClient
import azure.cosmos.exceptions as exceptions
from azure.cosmos.partition_key import PartitionKey
Expand Down
116 changes: 116 additions & 0 deletions sdk/cosmos/azure-cosmos/samples/access_cosmos_with_aad_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from azure.cosmos.aio import CosmosClient
import azure.cosmos.exceptions as exceptions
from azure.cosmos.partition_key import PartitionKey
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential
import config
import asyncio

# ----------------------------------------------------------------------------------------------------------
# Prerequistes -
#
# 1. An Azure Cosmos account -
# https://docs.microsoft.com/azure/cosmos-db/create-sql-api-python#create-a-database-account
#
# 2. Microsoft Azure Cosmos
# pip install azure-cosmos>=4.3.0b4
# ----------------------------------------------------------------------------------------------------------
# Sample - demonstrates how to authenticate and use your database account using AAD credentials
# Read more about operations allowed for this authorization method: https://aka.ms/cosmos-native-rbac
# ----------------------------------------------------------------------------------------------------------
# Note:
# This sample creates a Container to your database account.
# Each time a Container is created the account will be billed for 1 hour of usage based on
# the provisioned throughput (RU/s) of that account.
# ----------------------------------------------------------------------------------------------------------
# <configureConnectivity>
HOST = config.settings["host"]
MASTER_KEY = config.settings["master_key"]

TENANT_ID = config.settings["tenant_id"]
CLIENT_ID = config.settings["client_id"]
CLIENT_SECRET = config.settings["client_secret"]

DATABASE_ID = config.settings["database_id"]
CONTAINER_ID = config.settings["container_id"]
PARTITION_KEY = PartitionKey(path="/id")


def get_test_item(num):
test_item = {
'id': 'Item_' + str(num),
'test_object': True,
'lastName': 'Smith'
}
return test_item


async def create_sample_resources():
print("creating sample resources")
async with CosmosClient(HOST, MASTER_KEY) as client:
db = await client.create_database(DATABASE_ID)
await db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY)


async def delete_sample_resources():
print("deleting sample resources")
async with CosmosClient(HOST, MASTER_KEY) as client:
await client.delete_database(DATABASE_ID)


async def run_sample():
# Since Azure Cosmos DB data plane SDK does not cover management operations, we have to create our resources
# with a master key authenticated client for this sample.
await create_sample_resources()

# With this done, you can use your AAD service principal id and secret to create your ClientSecretCredential.
# The async ClientSecretCredentials, like the async client, also have a context manager,
# and as such should be used with the `async with` keywords.
async with ClientSecretCredential(
tenant_id=TENANT_ID,
client_id=CLIENT_ID,
client_secret=CLIENT_SECRET) as aad_credentials:

# Use your credentials to authenticate your client.
async with CosmosClient(HOST, aad_credentials) as aad_client:
print("Showed ClientSecretCredential, now showing DefaultAzureCredential")

# You can also utilize DefaultAzureCredential rather than directly passing in the id's and secrets.
# This is the recommended method of authentication, and uses environment variables rather than in-code strings.
async with DefaultAzureCredential() as aad_credentials:

# Use your credentials to authenticate your client.
async with CosmosClient(HOST, aad_credentials) as aad_client:

# Do any R/W data operations with your authorized AAD client.
db = aad_client.get_database_client(DATABASE_ID)
container = db.get_container_client(CONTAINER_ID)

print("Container info: " + str(container.read()))
await container.create_item(get_test_item(879))
print("Point read result: " + str(container.read_item(item='Item_0', partition_key='Item_0')))
query_results = [item async for item in
container.query_items(query='select * from c', partition_key='Item_0')]
assert len(query_results) == 1
print("Query result: " + str(query_results[0]))
await container.delete_item(item='Item_0', partition_key='Item_0')

# Attempting to do management operations will return a 403 Forbidden exception.
try:
await aad_client.delete_database(DATABASE_ID)
except exceptions.CosmosHttpResponseError as e:
assert e.status_code == 403
print("403 error assertion success")

# To clean up the sample, we use a master key client again to get access to deleting containers/ databases.
await delete_sample_resources()
print("end of sample")


if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(run_sample())
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import azure.cosmos.cosmos_client as cosmos_client
import azure.cosmos.exceptions as exceptions
from azure.cosmos.partition_key import PartitionKey
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import azure.cosmos.aio.cosmos_client as cosmos_client
import azure.cosmos.exceptions as exceptions
from azure.cosmos.partition_key import PartitionKey
Expand Down
5 changes: 5 additions & 0 deletions sdk/cosmos/azure-cosmos/samples/change_feed_management.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import azure.cosmos.documents as documents
import azure.cosmos.cosmos_client as cosmos_client
import azure.cosmos.exceptions as exceptions
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import azure.cosmos.aio.cosmos_client as cosmos_client
import azure.cosmos.exceptions as exceptions
import azure.cosmos.documents as documents
Expand Down
Loading

0 comments on commit c17568f

Please sign in to comment.