Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cosmos] AAD authentication async client #23717

Merged
merged 19 commits into from
Apr 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### 4.3.0b4 (Unreleased)

#### Features Added
- Added support for AAD authentication for the async client
- Added support for AAD authentication for the sync client

### 4.3.0b3 (2022-03-10)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@
from ._auth_policy import CosmosBearerTokenCredentialPolicy

ClassType = TypeVar("ClassType")


# pylint: disable=protected-access


Expand Down
175 changes: 175 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import asyncio
import time

from typing import Any, Awaitable, Optional, Dict, Union
from azure.core.pipeline.policies import AsyncHTTPPolicy
from azure.core.credentials import AccessToken
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.cosmos import http_constants


async def await_result(func, *args, **kwargs):
"""If func returns an awaitable, await it."""
result = func(*args, **kwargs)
if hasattr(result, "__await__"):
# type ignore on await: https://github.com/python/mypy/issues/7587
return await result # type: ignore
return result


class _AsyncCosmosBearerTokenCredentialPolicyBase(object):
"""Base class for a Bearer Token Credential Policy.

:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (TokenCredential, *str, **Any) -> None
super(_AsyncCosmosBearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._token = None # type: Optional[AccessToken]
self._lock = asyncio.Lock()

@staticmethod
def _enforce_https(request):
# type: (PipelineRequest) -> None

# move 'enforce_https' from options to context so it persists
# across retries but isn't passed to a transport implementation
option = request.context.options.pop("enforce_https", None)

# True is the default setting; we needn't preserve an explicit opt in to the default behavior
if option is False:
request.context["enforce_https"] = option

enforce_https = request.context.get("enforce_https", True)
if enforce_https and not request.http_request.url.lower().startswith("https"):
raise ValueError(
"Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
)

@staticmethod
def _update_headers(headers, token):
# type: (Dict[str, str], str) -> None
"""Updates the Authorization header with the cosmos signature and bearer token.
This is the main method that differentiates this policy from core's BearerTokenCredentialPolicy and works
to properly sign the authorization header for Cosmos' REST API. For more information:
https://docs.microsoft.com/rest/api/cosmos-db/access-control-on-cosmosdb-resources#authorization-header

:param dict headers: The HTTP Request headers
:param str token: The OAuth token.
"""
headers[http_constants.HttpHeaders.Authorization] = "type=aad&ver=1.0&sig={}".format(token)

@property
def _need_new_token(self) -> bool:
return not self._token or self._token.expires_on - time.time() < 300


class AsyncCosmosBearerTokenCredentialPolicy(_AsyncCosmosBearerTokenCredentialPolicyBase, AsyncHTTPPolicy):
"""Adds a bearer token Authorization header to requests.

:param credential: The credential.
:type credential: ~azure.core.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:raises ValueError: If https_enforce does not match with endpoint being used.
"""

async def on_request(self, request: "PipelineRequest") -> None: # pylint:disable=invalid-overridden-method
"""Adds a bearer token Authorization header to request and sends request to next policy.

:param request: The pipeline request object to be modified.
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
self._enforce_https(request) # pylint:disable=protected-access

if self._token is None or self._need_new_token:
async with self._lock:
# double check because another coroutine may have acquired a token while we waited to acquire the lock
if self._token is None or self._need_new_token:
self._token = await self._credential.get_token(*self._scopes)
self._update_headers(request.http_request.headers, self._token.token)

async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: "Any") -> None:
"""Acquire a token from the credential and authorize the request with it.

Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
authorize future requests.

:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
async with self._lock:
self._token = await self._credential.get_token(*scopes, **kwargs)
self._update_headers(request.http_request.headers, self._token.token)

async def send(self, request: "PipelineRequest") -> "PipelineResponse":
"""Authorize request with a bearer token and send it to the next policy

:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
"""
await await_result(self.on_request, request)
try:
response = await self.next.send(request)
await await_result(self.on_response, request, response)
except Exception: # pylint:disable=broad-except
handled = await await_result(self.on_exception, request)
if not handled:
raise
else:
if response.http_response.status_code == 401:
self._token = None # any cached token is invalid
if "WWW-Authenticate" in response.http_response.headers:
request_authorized = await self.on_challenge(request, response)
if request_authorized:
try:
response = await self.next.send(request)
await await_result(self.on_response, request, response)
except Exception: # pylint:disable=broad-except
handled = await await_result(self.on_exception, request)
if not handled:
raise

return response

async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
"""Authorize request according to an authentication challenge

This method is called when the resource provider responds 401 with a WWW-Authenticate header.

:param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
:param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
:returns: a bool indicating whether the policy should send the request
"""
# pylint:disable=unused-argument,no-self-use
return False

def on_response(self, request: PipelineRequest, response: PipelineResponse) -> Union[None, Awaitable[None]]:
"""Executed after the request comes back from the next policy.

:param request: Request to be modified after returning from the policy.
:type request: ~azure.core.pipeline.PipelineRequest
:param response: Pipeline response object
:type response: ~azure.core.pipeline.PipelineResponse
"""

def on_exception(self, request: PipelineRequest) -> None:
"""Executed when an exception is raised while executing the next policy.

This method is executed inside the exception handler.

:param request: The Pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
"""
# pylint: disable=no-self-use,unused-argument
return
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@
from .. import _session
from .. import _utils
from ..partition_key import _Undefined, _Empty
from ._auth_policy_async import AsyncCosmosBearerTokenCredentialPolicy

ClassType = TypeVar("ClassType")
# pylint: disable=protected-access


class CosmosClientConnection(object): # pylint: disable=too-many-public-methods,too-many-instance-attributes
"""Represents a document client.

Expand Down Expand Up @@ -113,9 +115,11 @@ def __init__(

self.master_key = None
self.resource_tokens = None
self.aad_credentials = None
if auth is not None:
self.master_key = auth.get("masterKey")
self.resource_tokens = auth.get("resourceTokens")
self.aad_credentials = auth.get("clientSecretCredential")

if auth.get("permissionFeed"):
self.resource_tokens = {}
Expand Down Expand Up @@ -176,12 +180,18 @@ def __init__(

self._user_agent = _utils.get_user_agent_async()

credentials_policy = None
if self.aad_credentials:
scopes = base.create_scope_from_url(self.url_connection)
credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(self.aad_credentials, scopes)

policies = [
HeadersPolicy(**kwargs),
ProxyPolicy(proxies=proxies),
UserAgentPolicy(base_user_agent=self._user_agent, **kwargs),
ContentDecodePolicy(),
retry_policy,
credentials_policy,
CustomHookPolicy(**kwargs),
NetworkTraceLoggingPolicy(**kwargs),
DistributedTracingPolicy(**kwargs),
Expand Down
5 changes: 5 additions & 0 deletions sdk/cosmos/azure-cosmos/samples/access_cosmos_with_aad.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from azure.cosmos import CosmosClient
import azure.cosmos.exceptions as exceptions
from azure.cosmos.partition_key import PartitionKey
Expand Down
116 changes: 116 additions & 0 deletions sdk/cosmos/azure-cosmos/samples/access_cosmos_with_aad_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from azure.cosmos.aio import CosmosClient
simorenoh marked this conversation as resolved.
Show resolved Hide resolved
import azure.cosmos.exceptions as exceptions
from azure.cosmos.partition_key import PartitionKey
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential
import config
import asyncio

# ----------------------------------------------------------------------------------------------------------
# Prerequistes -
#
# 1. An Azure Cosmos account -
# https://docs.microsoft.com/azure/cosmos-db/create-sql-api-python#create-a-database-account
#
# 2. Microsoft Azure Cosmos
# pip install azure-cosmos>=4.3.0b4
# ----------------------------------------------------------------------------------------------------------
# Sample - demonstrates how to authenticate and use your database account using AAD credentials
# Read more about operations allowed for this authorization method: https://aka.ms/cosmos-native-rbac
# ----------------------------------------------------------------------------------------------------------
# Note:
# This sample creates a Container to your database account.
# Each time a Container is created the account will be billed for 1 hour of usage based on
# the provisioned throughput (RU/s) of that account.
# ----------------------------------------------------------------------------------------------------------
# <configureConnectivity>
HOST = config.settings["host"]
MASTER_KEY = config.settings["master_key"]

TENANT_ID = config.settings["tenant_id"]
CLIENT_ID = config.settings["client_id"]
CLIENT_SECRET = config.settings["client_secret"]

DATABASE_ID = config.settings["database_id"]
CONTAINER_ID = config.settings["container_id"]
PARTITION_KEY = PartitionKey(path="/id")


def get_test_item(num):
test_item = {
'id': 'Item_' + str(num),
'test_object': True,
'lastName': 'Smith'
}
return test_item


async def create_sample_resources():
print("creating sample resources")
async with CosmosClient(HOST, MASTER_KEY) as client:
db = await client.create_database(DATABASE_ID)
await db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY)


async def delete_sample_resources():
print("deleting sample resources")
async with CosmosClient(HOST, MASTER_KEY) as client:
await client.delete_database(DATABASE_ID)


async def run_sample():
# Since Azure Cosmos DB data plane SDK does not cover management operations, we have to create our resources
# with a master key authenticated client for this sample.
await create_sample_resources()

# With this done, you can use your AAD service principal id and secret to create your ClientSecretCredential.
# The async ClientSecretCredentials, like the async client, also have a context manager,
# and as such should be used with the `async with` keywords.
async with ClientSecretCredential(
tenant_id=TENANT_ID,
client_id=CLIENT_ID,
client_secret=CLIENT_SECRET) as aad_credentials:

# Use your credentials to authenticate your client.
async with CosmosClient(HOST, aad_credentials) as aad_client:
print("Showed ClientSecretCredential, now showing DefaultAzureCredential")

# You can also utilize DefaultAzureCredential rather than directly passing in the id's and secrets.
# This is the recommended method of authentication, and uses environment variables rather than in-code strings.
async with DefaultAzureCredential() as aad_credentials:

# Use your credentials to authenticate your client.
async with CosmosClient(HOST, aad_credentials) as aad_client:

# Do any R/W data operations with your authorized AAD client.
db = aad_client.get_database_client(DATABASE_ID)
container = db.get_container_client(CONTAINER_ID)

print("Container info: " + str(container.read()))
await container.create_item(get_test_item(879))
print("Point read result: " + str(container.read_item(item='Item_0', partition_key='Item_0')))
query_results = [item async for item in
container.query_items(query='select * from c', partition_key='Item_0')]
assert len(query_results) == 1
print("Query result: " + str(query_results[0]))
await container.delete_item(item='Item_0', partition_key='Item_0')

# Attempting to do management operations will return a 403 Forbidden exception.
try:
await aad_client.delete_database(DATABASE_ID)
except exceptions.CosmosHttpResponseError as e:
assert e.status_code == 403
print("403 error assertion success")

# To clean up the sample, we use a master key client again to get access to deleting containers/ databases.
await delete_sample_resources()
print("end of sample")


if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(run_sample())
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import azure.cosmos.cosmos_client as cosmos_client
import azure.cosmos.exceptions as exceptions
from azure.cosmos.partition_key import PartitionKey
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import azure.cosmos.aio.cosmos_client as cosmos_client
import azure.cosmos.exceptions as exceptions
from azure.cosmos.partition_key import PartitionKey
Expand Down
5 changes: 5 additions & 0 deletions sdk/cosmos/azure-cosmos/samples/change_feed_management.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import azure.cosmos.documents as documents
import azure.cosmos.cosmos_client as cosmos_client
import azure.cosmos.exceptions as exceptions
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import azure.cosmos.aio.cosmos_client as cosmos_client
import azure.cosmos.exceptions as exceptions
import azure.cosmos.documents as documents
Expand Down
Loading