diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md
index 13530782ad21..1b91b0609950 100644
--- a/sdk/cosmos/azure-cosmos/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md
@@ -3,12 +3,7 @@
### 4.3.0b4 (Unreleased)
#### Features Added
-
-#### Breaking Changes
-
-#### Bugs Fixed
-
-#### Other Changes
+- Added support for AAD authentication for the sync client
### 4.3.0b3 (2022-03-10)
diff --git a/sdk/cosmos/azure-cosmos/README.md b/sdk/cosmos/azure-cosmos/README.md
index 0010d24a8c9b..9930ae0530a5 100644
--- a/sdk/cosmos/azure-cosmos/README.md
+++ b/sdk/cosmos/azure-cosmos/README.md
@@ -76,6 +76,35 @@ KEY = os.environ['ACCOUNT_KEY']
client = CosmosClient(URL, credential=KEY)
```
+### AAD Authentication
+
+You can also authenticate a client utilizing your service principal's AAD credentials and the azure identity package.
+You can directly pass in the credentials information to ClientSecretCrednetial, or use the DefaultAzureCredential:
+```Python
+from azure.cosmos import CosmosClient
+from azure.identity import ClientSecretCredential, DefaultAzureCredential
+
+import os
+url = os.environ['ACCOUNT_URI']
+tenant_id = os.environ['TENANT_ID']
+client_id = os.environ['CLIENT_ID']
+client_secret = os.environ['CLIENT_SECRET']
+
+# Using ClientSecretCredential
+aad_credentials = ClientSecretCredential(
+ tenant_id=tenant_id,
+ client_id=client_id,
+ client_secret=client_secret)
+
+# Using DefaultAzureCredential (recommended)
+aad_credentials = DefaultAzureCredential()
+
+client = CosmosClient(url, aad_credentials)
+```
+Always ensure that the managed identity you use for AAD authentication has `readMetadata` permissions.
+More information on how to set up AAD authentication: [Set up RBAC for AAD authentication](https://docs.microsoft.com/azure/cosmos-db/how-to-setup-rbac)
+More information on allowed operations for AAD authenticated clients: [RBAC Permission Model](https://aka.ms/cosmos-native-rbac)
+
## Key concepts
Once you've initialized a [CosmosClient][ref_cosmosclient], you can interact with the primary resource types in Cosmos DB:
@@ -125,7 +154,7 @@ Currently the features below are **not supported**. For alternatives options, ch
* Change Feed: Processor
* Change Feed: Read multiple partitions key values
* Change Feed: Read specific time
-* Change Feed: Read from the beggining
+* Change Feed: Read from the beginning
* Change Feed: Pull model
* Cross-partition ORDER BY for mixed types
@@ -139,10 +168,6 @@ Currently the features below are **not supported**. For alternatives options, ch
* Get the connection string
* Get the minimum RU/s of a container
-### Security Limitations:
-
-* AAD support
-
## Workarounds
### Bulk processing Limitation Workaround
@@ -153,10 +178,6 @@ If you want to use Python SDK to perform bulk inserts to Cosmos DB, the best alt
Typically, you can use [Azure Portal](https://portal.azure.com/), [Azure Cosmos DB Resource Provider REST API](https://docs.microsoft.com/rest/api/cosmos-db-resource-provider), [Azure CLI](https://docs.microsoft.com/cli/azure/azure-cli-reference-for-cosmos-db) or [PowerShell](https://docs.microsoft.com/azure/cosmos-db/manage-with-powershell) for the control plane unsupported limitations.
-### AAD Support Workaround
-
-A possible workaround is to use managed identities to [programmatically](https://docs.microsoft.com/azure/cosmos-db/managed-identity-based-authentication) get the keys.
-
## Boolean Data Type
While the Python language [uses](https://docs.python.org/3/library/stdtypes.html?highlight=boolean#truth-value-testing) "True" and "False" for boolean types, Cosmos DB [accepts](https://docs.microsoft.com/azure/cosmos-db/sql-query-is-bool) "true" and "false" only. In other words, the Python language uses Boolean values with the first uppercase letter and all other lowercase letters, while Cosmos DB and its SQL language use only lowercase letters for those same Boolean values. How to deal with this challenge?
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py
new file mode 100644
index 000000000000..1210335e37e7
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py
@@ -0,0 +1,166 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See LICENSE.txt in the project root for
+# license information.
+# -------------------------------------------------------------------------
+import time
+
+from typing import Any, Dict, Optional
+from azure.core.credentials import AccessToken
+from azure.core.pipeline import PipelineRequest, PipelineResponse
+from azure.core.pipeline.policies import HTTPPolicy
+from azure.cosmos import http_constants
+
+
+# pylint:disable=too-few-public-methods
+class _CosmosBearerTokenCredentialPolicyBase(object):
+ """Base class for a Bearer Token Credential Policy.
+
+ :param credential: The credential.
+ :type credential: ~azure.core.credentials.TokenCredential
+ :param str scopes: Lets you specify the type of access needed.
+ """
+
+ def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument
+ # type: (TokenCredential, *str, **Any) -> None
+ super(_CosmosBearerTokenCredentialPolicyBase, self).__init__()
+ self._scopes = scopes
+ self._credential = credential
+ self._token = None # type: Optional[AccessToken]
+
+ @staticmethod
+ def _enforce_https(request):
+ # type: (PipelineRequest) -> None
+
+ # move 'enforce_https' from options to context so it persists
+ # across retries but isn't passed to a transport implementation
+ option = request.context.options.pop("enforce_https", None)
+
+ # True is the default setting; we needn't preserve an explicit opt in to the default behavior
+ if option is False:
+ request.context["enforce_https"] = option
+
+ enforce_https = request.context.get("enforce_https", True)
+ if enforce_https and not request.http_request.url.lower().startswith("https"):
+ raise ValueError(
+ "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
+ )
+
+ @staticmethod
+ def _update_headers(headers, token):
+ # type: (Dict[str, str], str) -> None
+ """Updates the Authorization header with the bearer token.
+ This is the main method that differentiates this policy from core's BearerTokenCredentialPolicy and works
+ to properly sign the authorization header for Cosmos' REST API. For more information:
+ https://docs.microsoft.com/rest/api/cosmos-db/access-control-on-cosmosdb-resources#authorization-header
+
+ :param dict headers: The HTTP Request headers
+ :param str token: The OAuth token.
+ """
+ headers[http_constants.HttpHeaders.Authorization] = "type=aad&ver=1.0&sig={}".format(token)
+
+ @property
+ def _need_new_token(self):
+ # type: () -> bool
+ return not self._token or self._token.expires_on - time.time() < 300
+
+
+class CosmosBearerTokenCredentialPolicy(_CosmosBearerTokenCredentialPolicyBase, HTTPPolicy):
+ """Adds a bearer token Authorization header to requests.
+
+ :param credential: The credential.
+ :type credential: ~azure.core.TokenCredential
+ :param str scopes: Lets you specify the type of access needed.
+ :raises ValueError: If https_enforce does not match with endpoint being used.
+ """
+
+ def on_request(self, request):
+ # type: (PipelineRequest) -> None
+ """Called before the policy sends a request.
+
+ The base implementation authorizes the request with a bearer token.
+
+ :param ~azure.core.pipeline.PipelineRequest request: the request
+ """
+ self._enforce_https(request)
+
+ if self._token is None or self._need_new_token:
+ self._token = self._credential.get_token(*self._scopes)
+ self._update_headers(request.http_request.headers, self._token.token)
+
+ def authorize_request(self, request, *scopes, **kwargs):
+ # type: (PipelineRequest, *str, **Any) -> None
+ """Acquire a token from the credential and authorize the request with it.
+
+ Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
+ authorize future requests.
+
+ :param ~azure.core.pipeline.PipelineRequest request: the request
+ :param str scopes: required scopes of authentication
+ """
+ self._token = self._credential.get_token(*scopes, **kwargs)
+ self._update_headers(request.http_request.headers, self._token.token)
+
+ def send(self, request):
+ # type: (PipelineRequest) -> PipelineResponse
+ """Authorize request with a bearer token and send it to the next policy
+
+ :param request: The pipeline request object
+ :type request: ~azure.core.pipeline.PipelineRequest
+ """
+ self.on_request(request)
+ try:
+ response = self.next.send(request)
+ self.on_response(request, response)
+ except Exception: # pylint:disable=broad-except
+ self.on_exception(request)
+ raise
+ else:
+ if response.http_response.status_code == 401:
+ self._token = None # any cached token is invalid
+ if "WWW-Authenticate" in response.http_response.headers:
+ request_authorized = self.on_challenge(request, response)
+ if request_authorized:
+ try:
+ response = self.next.send(request)
+ self.on_response(request, response)
+ except Exception: # pylint:disable=broad-except
+ self.on_exception(request)
+ raise
+
+ return response
+
+ def on_challenge(self, request, response):
+ # type: (PipelineRequest, PipelineResponse) -> bool
+ """Authorize request according to an authentication challenge
+
+ This method is called when the resource provider responds 401 with a WWW-Authenticate header.
+
+ :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
+ :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
+ :returns: a bool indicating whether the policy should send the request
+ """
+ # pylint:disable=unused-argument,no-self-use
+ return False
+
+ def on_response(self, request, response):
+ # type: (PipelineRequest, PipelineResponse) -> None
+ """Executed after the request comes back from the next policy.
+
+ :param request: Request to be modified after returning from the policy.
+ :type request: ~azure.core.pipeline.PipelineRequest
+ :param response: Pipeline response object
+ :type response: ~azure.core.pipeline.PipelineResponse
+ """
+
+ def on_exception(self, request):
+ # type: (PipelineRequest) -> None
+ """Executed when an exception is raised while executing the next policy.
+
+ This method is executed inside the exception handler.
+
+ :param request: The Pipeline request object
+ :type request: ~azure.core.pipeline.PipelineRequest
+ """
+ # pylint: disable=no-self-use,unused-argument
+ return
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
index f2539bc30319..c9a8af14ecb6 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
@@ -30,6 +30,7 @@
from typing import Dict, Any
from urllib.parse import quote as urllib_quote
+from urllib.parse import urlsplit
from azure.core import MatchConditions
@@ -663,6 +664,11 @@ def ParsePaths(paths):
return tokens
+def create_scope_from_url(url):
+ parsed_url = urlsplit(url)
+ return parsed_url.scheme + "://" + parsed_url.hostname + "/.default"
+
+
def validate_cache_staleness_value(max_integrated_cache_staleness):
int(max_integrated_cache_staleness) # Will throw error if data type cant be converted to int
if max_integrated_cache_staleness <= 0:
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py
index 0f831405c073..821ddc00adfc 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py
@@ -58,6 +58,7 @@
from . import _session
from . import _utils
from .partition_key import _Undefined, _Empty
+from ._auth_policy import CosmosBearerTokenCredentialPolicy
ClassType = TypeVar("ClassType")
@@ -116,9 +117,11 @@ def __init__(
self.master_key = None
self.resource_tokens = None
+ self.aad_credentials = None
if auth is not None:
self.master_key = auth.get("masterKey")
self.resource_tokens = auth.get("resourceTokens")
+ self.aad_credentials = auth.get("clientSecretCredential")
if auth.get("permissionFeed"):
self.resource_tokens = {}
@@ -176,12 +179,18 @@ def __init__(
self._user_agent = _utils.get_user_agent()
+ credentials_policy = None
+ if self.aad_credentials:
+ scopes = base.create_scope_from_url(self.url_connection)
+ credentials_policy = CosmosBearerTokenCredentialPolicy(self.aad_credentials, scopes)
+
policies = [
HeadersPolicy(**kwargs),
ProxyPolicy(proxies=proxies),
UserAgentPolicy(base_user_agent=self._user_agent, **kwargs),
ContentDecodePolicy(),
retry_policy,
+ credentials_policy,
CustomHookPolicy(**kwargs),
NetworkTraceLoggingPolicy(**kwargs),
DistributedTracingPolicy(**kwargs),
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py b/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py
index 9c8252bd9869..d3c2cb0d8017 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/auth.py
@@ -26,12 +26,11 @@
from hashlib import sha256
import hmac
import urllib.parse
-
from . import http_constants
def GetAuthorizationHeader(
- cosmos_client_connection, verb, path, resource_id_or_fullname, is_name_based, resource_type, headers
+ cosmos_client_connection, verb, path, resource_id_or_fullname, is_name_based, resource_type, headers
):
"""Gets the authorization header.
@@ -51,18 +50,18 @@ def GetAuthorizationHeader(
resource_id_or_fullname = resource_id_or_fullname.lower()
if cosmos_client_connection.master_key:
- return __GetAuthorizationTokenUsingMasterKey(
+ return __get_authorization_token_using_master_key(
verb, resource_id_or_fullname, resource_type, headers, cosmos_client_connection.master_key
)
if cosmos_client_connection.resource_tokens:
- return __GetAuthorizationTokenUsingResourceTokens(
+ return __get_authorization_token_using_resource_token(
cosmos_client_connection.resource_tokens, path, resource_id_or_fullname
)
return None
-def __GetAuthorizationTokenUsingMasterKey(verb, resource_id_or_fullname, resource_type, headers, master_key):
+def __get_authorization_token_using_master_key(verb, resource_id_or_fullname, resource_type, headers, master_key):
"""Gets the authorization token using `master_key.
:param str verb:
@@ -97,7 +96,7 @@ def __GetAuthorizationTokenUsingMasterKey(verb, resource_id_or_fullname, resourc
return "type={type}&ver={ver}&sig={sig}".format(type=master_token, ver=token_version, sig=signature[:-1])
-def __GetAuthorizationTokenUsingResourceTokens(resource_tokens, path, resource_id_or_fullname):
+def __get_authorization_token_using_resource_token(resource_tokens, path, resource_id_or_fullname):
"""Get the authorization token using `resource_tokens`.
:param dict resource_tokens:
@@ -138,7 +137,7 @@ def __GetAuthorizationTokenUsingResourceTokens(resource_tokens, path, resource_i
# Get the last resource id or resource name from the path and get it's token from resource_tokens
for i in range(len(path_parts), 1, -1):
- segment = path_parts[i-1]
+ segment = path_parts[i - 1]
sub_path = "/".join(path_parts[:i])
if not segment in resource_types and sub_path in resource_tokens:
return resource_tokens[sub_path]
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py
index c581dcdf6baa..f4e08a0e9e81 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py
@@ -60,10 +60,13 @@ def _build_auth(credential):
auth['resourceTokens'] = credential # type: ignore
elif hasattr(credential, '__iter__'):
auth['permissionFeed'] = credential
+ elif hasattr(credential, 'get_token'):
+ auth['clientSecretCredential'] = credential
else:
raise TypeError(
- "Unrecognized credential type. Please supply the master key as str, "
- "or a dictionary or resource tokens, or a list of permissions.")
+ "Unrecognized credential type. Please supply the master key as a string "
+ "or a dictionary, or resource tokens, or a list of permissions, or any instance of a class implementing"
+ " TokenCredential (see azure.identity module for specific implementations such as ClientSecretCredential).")
return auth
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py
index 785ac2b728c1..f977b819a97f 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py
@@ -379,6 +379,7 @@ class SubStatusCodes(object):
REDUNDANT_COLLECTION_PUT = 1009
SHARED_THROUGHPUT_DATABASE_QUOTA_EXCEEDED = 1010
SHARED_THROUGHPUT_OFFER_GROW_NOT_NEEDED = 1011
+ AAD_REQUEST_NOT_AUTHORIZED = 5300
# 404: LSN in session token is higher
READ_SESSION_NOTAVAILABLE = 1002
diff --git a/sdk/cosmos/azure-cosmos/dev_requirements.txt b/sdk/cosmos/azure-cosmos/dev_requirements.txt
index 29e4ad8aefb6..e6fcb54061af 100644
--- a/sdk/cosmos/azure-cosmos/dev_requirements.txt
+++ b/sdk/cosmos/azure-cosmos/dev_requirements.txt
@@ -1,3 +1,4 @@
azure-core
+azure-identity
-e ../../../tools/azure-sdk-tools
-e ../../../tools/azure-devtools
\ No newline at end of file
diff --git a/sdk/cosmos/azure-cosmos/samples/access_cosmos_with_aad.py b/sdk/cosmos/azure-cosmos/samples/access_cosmos_with_aad.py
new file mode 100644
index 000000000000..33f080fce9f5
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/samples/access_cosmos_with_aad.py
@@ -0,0 +1,102 @@
+from azure.cosmos import CosmosClient
+import azure.cosmos.exceptions as exceptions
+from azure.cosmos.partition_key import PartitionKey
+from azure.identity import ClientSecretCredential, DefaultAzureCredential
+import config
+
+# ----------------------------------------------------------------------------------------------------------
+# Prerequistes -
+#
+# 1. An Azure Cosmos account -
+# https://docs.microsoft.com/azure/cosmos-db/create-sql-api-python#create-a-database-account
+#
+# 2. Microsoft Azure Cosmos
+# pip install azure-cosmos>=4.3.0b4
+# ----------------------------------------------------------------------------------------------------------
+# Sample - demonstrates how to authenticate and use your database account using AAD credentials
+# Read more about operations allowed for this authorization method: https://aka.ms/cosmos-native-rbac
+# ----------------------------------------------------------------------------------------------------------
+# Note:
+# This sample creates a Container to your database account.
+# Each time a Container is created the account will be billed for 1 hour of usage based on
+# the provisioned throughput (RU/s) of that account.
+# ----------------------------------------------------------------------------------------------------------
+#
+HOST = config.settings["host"]
+MASTER_KEY = config.settings["master_key"]
+
+TENANT_ID = config.settings["tenant_id"]
+CLIENT_ID = config.settings["client_id"]
+CLIENT_SECRET = config.settings["client_secret"]
+
+DATABASE_ID = config.settings["database_id"]
+CONTAINER_ID = config.settings["container_id"]
+PARTITION_KEY = PartitionKey(path="/id")
+
+
+def get_test_item(num):
+ test_item = {
+ 'id': 'Item_' + str(num),
+ 'test_object': True,
+ 'lastName': 'Smith'
+ }
+ return test_item
+
+
+def create_sample_resources():
+ print("creating sample resources")
+ client = CosmosClient(HOST, MASTER_KEY)
+ db = client.create_database(DATABASE_ID)
+ db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY)
+
+
+def delete_sample_resources():
+ print("deleting sample resources")
+ client = CosmosClient(HOST, MASTER_KEY)
+ client.delete_database(DATABASE_ID)
+
+
+def run_sample():
+ # Since Azure Cosmos DB data plane SDK does not cover management operations, we have to create our resources
+ # with a master key authenticated client for this sample.
+ create_sample_resources()
+
+ # With this done, you can use your AAD service principal id and secret to create your ClientSecretCredential.
+ aad_credentials = ClientSecretCredential(
+ tenant_id=TENANT_ID,
+ client_id=CLIENT_ID,
+ client_secret=CLIENT_SECRET)
+
+ # You can also utilize DefaultAzureCredential rather than directly passing in the id's and secrets.
+ # This is the recommended method of authentication, and uses environment variables rather than in-code strings.
+ aad_credentials = DefaultAzureCredential()
+
+ # Use your credentials to authenticate your client.
+ aad_client = CosmosClient(HOST, aad_credentials)
+
+ # Do any R/W data operations with your authorized AAD client.
+ db = aad_client.get_database_client(DATABASE_ID)
+ container = db.get_container_client(CONTAINER_ID)
+
+ print("Container info: " + str(container.read()))
+ container.create_item(get_test_item(0))
+ print("Point read result: " + str(container.read_item(item='Item_0', partition_key='Item_0')))
+ query_results = list(container.query_items(query='select * from c', partition_key='Item_0'))
+ assert len(query_results) == 1
+ print("Query result: " + str(query_results[0]))
+ container.delete_item(item='Item_0', partition_key='Item_0')
+
+ # Attempting to do management operations will return a 403 Forbidden exception.
+ try:
+ aad_client.delete_database(DATABASE_ID)
+ except exceptions.CosmosHttpResponseError as e:
+ assert e.status_code == 403
+ print("403 error assertion success")
+
+ # To clean up the sample, we use a master key client again to get access to deleting containers and databases.
+ delete_sample_resources()
+ print("end of sample")
+
+
+if __name__ == "__main__":
+ run_sample()
diff --git a/sdk/cosmos/azure-cosmos/samples/config.py b/sdk/cosmos/azure-cosmos/samples/config.py
index a85ac445a84c..a69ee67c0e88 100644
--- a/sdk/cosmos/azure-cosmos/samples/config.py
+++ b/sdk/cosmos/azure-cosmos/samples/config.py
@@ -5,4 +5,7 @@
'master_key': os.environ.get('ACCOUNT_KEY', '[YOUR KEY]'),
'database_id': os.environ.get('COSMOS_DATABASE', '[YOUR DATABASE]'),
'container_id': os.environ.get('COSMOS_CONTAINER', '[YOUR CONTAINER]'),
+ 'tenant_id': os.environ.get('TENANT_ID', '[YOUR TENANT ID]'),
+ 'client_id': os.environ.get('CLIENT_ID', '[YOUR CLIENT ID]'),
+ 'client_secret': os.environ.get('CLIENT_SECRET', '[YOUR CLIENT SECRET]'),
}
diff --git a/sdk/cosmos/azure-cosmos/test/test_aad.py b/sdk/cosmos/azure-cosmos/test/test_aad.py
new file mode 100644
index 000000000000..35c620c24c40
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/test/test_aad.py
@@ -0,0 +1,158 @@
+# The MIT License (MIT)
+# Copyright (c) 2022 Microsoft Corporation
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import unittest
+
+import pytest
+import base64
+import time
+import json
+from io import StringIO
+
+import azure.cosmos.cosmos_client as cosmos_client
+from azure.cosmos import exceptions
+from azure.identity import ClientSecretCredential
+from azure.core import exceptions
+from azure.core.credentials import AccessToken
+import test_config
+
+pytestmark = pytest.mark.cosmosEmulator
+
+
+def _remove_padding(encoded_string):
+ while encoded_string.endswith("="):
+ encoded_string = encoded_string[0:len(encoded_string) - 1]
+
+ return encoded_string
+
+
+def get_test_item(num):
+ test_item = {
+ 'id': 'Item_' + str(num),
+ 'test_object': True,
+ 'lastName': 'Smith'
+ }
+ return test_item
+
+
+class CosmosEmulatorCredential(object):
+
+ def get_token(self, *scopes, **kwargs):
+ # type: (*str, **Any) -> AccessToken
+ """Request an access token for the emulator. Based on Azure Core's Access Token Credential.
+
+ This method is called automatically by Azure SDK clients.
+
+ :param str scopes: desired scopes for the access token. This method requires at least one scope.
+ :rtype: :class:`azure.core.credentials.AccessToken`
+ :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks
+ required data, state, or platform support
+ :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
+ attribute gives a reason.
+ """
+ aad_header_cosmos_emulator = "{\"typ\":\"JWT\",\"alg\":\"RS256\",\"x5t\":\"" \
+ "CosmosEmulatorPrimaryMaster\",\"kid\":\"CosmosEmulatorPrimaryMaster\"}"
+
+ aad_claim_cosmos_emulator_format = {"aud": "https://localhost.localhost",
+ "iss": "https://sts.fake-issuer.net/7b1999a1-dfd7-440e-8204-00170979b984",
+ "iat": int(time.time()), "nbf": int(time.time()),
+ "exp": int(time.time() + 7200), "aio": "", "appid": "localhost",
+ "appidacr": "1", "idp": "https://localhost:8081/",
+ "oid": "96313034-4739-43cb-93cd-74193adbe5b6", "rh": "", "sub": "localhost",
+ "tid": "EmulatorFederation", "uti": "", "ver": "1.0",
+ "scp": "user_impersonation",
+ "groups": ["7ce1d003-4cb3-4879-b7c5-74062a35c66e",
+ "e99ff30c-c229-4c67-ab29-30a6aebc3e58",
+ "5549bb62-c77b-4305-bda9-9ec66b85d9e4",
+ "c44fd685-5c58-452c-aaf7-13ce75184f65",
+ "be895215-eab5-43b7-9536-9ef8fe130330"]}
+
+ emulator_key = test_config._test_config.masterKey
+
+ first_encoded_bytes = base64.urlsafe_b64encode(aad_header_cosmos_emulator.encode("utf-8"))
+ first_encoded_padded = str(first_encoded_bytes, "utf-8")
+ first_encoded = _remove_padding(first_encoded_padded)
+
+ str_io_obj = StringIO()
+ json.dump(aad_claim_cosmos_emulator_format, str_io_obj)
+ aad_claim_cosmos_emulator_format_string = str(str_io_obj.getvalue()).replace(" ", "")
+ second = aad_claim_cosmos_emulator_format_string
+ second_encoded_bytes = base64.urlsafe_b64encode(second.encode("utf-8"))
+ second_encoded_padded = str(second_encoded_bytes, "utf-8")
+ second_encoded = _remove_padding(second_encoded_padded)
+
+ emulator_key_encoded_bytes = base64.urlsafe_b64encode(emulator_key.encode("utf-8"))
+ emulator_key_encoded_padded = str(emulator_key_encoded_bytes, "utf-8")
+ emulator_key_encoded = _remove_padding(emulator_key_encoded_padded)
+
+ return AccessToken(first_encoded + "." + second_encoded + "." + emulator_key_encoded, int(time.time() + 7200))
+
+
+@pytest.mark.usefixtures("teardown")
+class AadTest(unittest.TestCase):
+ configs = test_config._test_config
+ host = configs.host
+ masterKey = configs.masterKey
+
+ @classmethod
+ def setUpClass(cls):
+ cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey)
+ cls.database = test_config._test_config.create_database_if_not_exist(cls.client)
+ cls.container = test_config._test_config.create_collection_if_not_exist_no_custom_throughput(cls.client)
+
+ def test_wrong_credentials(self):
+ wrong_aad_credentials = ClientSecretCredential(
+ "wrong_tenant_id",
+ "wrong_client_id",
+ "wrong_client_secret")
+
+ try:
+ cosmos_client.CosmosClient(self.host, wrong_aad_credentials)
+ except exceptions.ClientAuthenticationError as e:
+ print("Client successfully failed to authenticate with message: {}".format(e.message))
+
+ def test_emulator_aad_credentials(self):
+ if self.host != 'https://localhost:8081/':
+ print("This test is only configured to run on the emulator, skipping now.")
+ return
+
+ aad_client = cosmos_client.CosmosClient(self.host, CosmosEmulatorCredential())
+ # Do any R/W data operations with your authorized AAD client
+ db = aad_client.get_database_client(self.configs.TEST_DATABASE_ID)
+ container = db.get_container_client(self.configs.TEST_COLLECTION_SINGLE_PARTITION_ID)
+
+ print("Container info: " + str(container.read()))
+ container.create_item(get_test_item(0))
+ print("Point read result: " + str(container.read_item(item='Item_0', partition_key='Item_0')))
+ query_results = list(container.query_items(query='select * from c', partition_key='Item_0'))
+ assert len(query_results) == 1
+ print("Query result: " + str(query_results[0]))
+ container.delete_item(item='Item_0', partition_key='Item_0')
+
+ # Attempting to do management operations will return a 403 Forbidden exception
+ try:
+ aad_client.delete_database(self.configs.TEST_DATABASE_ID)
+ except exceptions.CosmosHttpResponseError as e:
+ assert e.status_code == 403
+ print("403 error assertion success")
+
+
+if __name__ == "__main__":
+ unittest.main()