Skip to content

Commit

Permalink
[Identity] Enable CAE toggle per token request
Browse files Browse the repository at this point in the history
Signed-off-by: Paul Van Eck <[email protected]>
  • Loading branch information
pvaneck committed Jun 20, 2023
1 parent dc395e5 commit e31f36f
Show file tree
Hide file tree
Showing 23 changed files with 198 additions and 75 deletions.
2 changes: 2 additions & 0 deletions sdk/core/azure-core/azure/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def get_token(
:keyword str claims: Additional claims required in the token, such as those returned in a resource
provider's claims challenge following an authorization failure.
:keyword str tenant_id: Optional tenant to include in the token request.
:keyword str enable_cae: Enables configuring "CP1" client capabilities on all token requests to support
Continuous Access Evaluation (CAE). Defaults to False.
:rtype: AccessToken
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
Expand Down
2 changes: 2 additions & 0 deletions sdk/core/azure-core/azure/core/credentials_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ async def get_token(
:keyword str claims: Additional claims required in the token, such as those returned in a resource
provider's claims challenge following an authorization failure.
:keyword str tenant_id: Optional tenant to include in the token request.
:keyword str enable_cae: Enables configuring "CP1" client capabilities on all token requests to support
Continuous Access Evaluation (CAE). Defaults to False.
:rtype: AccessToken
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@ class _BearerTokenCredentialPolicyBase:
:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword str enable_cae: Enables configuring "CP1" client capabilities on all token requests to support
Continuous Access Evaluation (CAE). Defaults to False.
"""

def __init__(self, credential: "TokenCredential", *scopes: str, **kwargs) -> None: # pylint:disable=unused-argument
super(_BearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._token: Optional["AccessToken"] = None
self._enable_cae: bool = kwargs.get("enable_cae", False)

@staticmethod
def _enforce_https(request: "PipelineRequest") -> None:
Expand Down Expand Up @@ -74,6 +77,8 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy):
:param credential: The credential.
:type credential: ~azure.core.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword str enable_cae: Enables configuring "CP1" client capabilities on all token requests to support
Continuous Access Evaluation (CAE). Defaults to False.
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""

Expand All @@ -87,7 +92,7 @@ def on_request(self, request: "PipelineRequest") -> None:
self._enforce_https(request)

if self._token is None or self._need_new_token:
self._token = self._credential.get_token(*self._scopes)
self._token = self._credential.get_token(*self._scopes, enable_cae=self._enable_cae)
self._update_headers(request.http_request.headers, self._token.token)

def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs) -> None:
Expand All @@ -99,6 +104,7 @@ def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs)
:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
kwargs.setdefault("enable_cae", self._enable_cae)
self._token = self._credential.get_token(*scopes, **kwargs)
self._update_headers(request.http_request.headers, self._token.token)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy):
:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword str enable_cae: Enables configuring "CP1" client capabilities on all token requests to support
Continuous Access Evaluation (CAE). Defaults to False.
"""

def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: Any) -> None:
Expand All @@ -35,6 +37,7 @@ def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: A
self._lock = asyncio.Lock()
self._scopes = scopes
self._token: Optional["AccessToken"] = None
self._enable_cae: bool = kwargs.get("enable_cae", False)

async def on_request(self, request: "PipelineRequest") -> None: # pylint:disable=invalid-overridden-method
"""Adds a bearer token Authorization header to request and sends request to next policy.
Expand All @@ -49,7 +52,7 @@ async def on_request(self, request: "PipelineRequest") -> None: # pylint:disabl
async with self._lock:
# double check because another coroutine may have acquired a token while we waited to acquire the lock
if self._token is None or self._need_new_token():
self._token = await await_result(self._credential.get_token, *self._scopes)
self._token = await await_result(self._credential.get_token, *self._scopes, enable_cae=self._enable_cae)
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token

async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: Any) -> None:
Expand All @@ -61,6 +64,7 @@ async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kw
:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
kwargs.setdefault("enable_cae", self._enable_cae)
async with self._lock:
self._token = await await_result(self._credential.get_token, *scopes, **kwargs)
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@
# --------------------------------------------------------------------------
import base64
import time
from typing import Optional, TypeVar
from typing import Optional, TypeVar, TYPE_CHECKING

from azure.core.pipeline.policies import BearerTokenCredentialPolicy, SansIOHTTPPolicy
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.exceptions import ServiceRequestError

if TYPE_CHECKING:
# pylint:disable=unused-import
from azure.core.credentials import TokenCredential


HTTPRequestType = TypeVar("HTTPRequestType")
HTTPResponseType = TypeVar("HTTPResponseType")
Expand All @@ -46,6 +50,10 @@ class ARMChallengeAuthenticationPolicy(BearerTokenCredentialPolicy):
:param str scopes: required authentication scopes
"""

def __init__(self, credential: TokenCredential, *scopes: str, **kwargs) -> None: # pylint:disable=unused-argument
kwargs.setdefault("enable_cae", True) # ARM supports Continuous Access Evaluation (CAE).
super().__init__(credential, *scopes, **kwargs)

def on_challenge(
self,
request: PipelineRequest[HTTPRequestType],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from typing import TypeVar, Awaitable, Optional
from typing import Any, TypeVar, Awaitable, Optional, TYPE_CHECKING
import inspect

from azure.core.pipeline.policies import (
Expand All @@ -34,6 +34,9 @@

from ._authentication import _parse_claims_challenge, _AuxiliaryAuthenticationPolicyBase

if TYPE_CHECKING:
from azure.core.credentials_async import AsyncTokenCredential


HTTPRequestType = TypeVar("HTTPRequestType")
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
Expand All @@ -57,6 +60,10 @@ class AsyncARMChallengeAuthenticationPolicy(AsyncBearerTokenCredentialPolicy):
:param str scopes: required authentication scopes
"""

def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None:
kwargs.setdefault("enable_cae", True) # ARM supports Continuous Access Evaluation (CAE).
super().__init__(credential, *scopes, **kwargs)

# pylint:disable=unused-argument
async def on_challenge(
self,
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

### Breaking Changes

- CP1 client capabilities (CAE) is no longer always-on by default for user credentials. This capability will now be configured as-needed in each `get_token` request by each SDK.

### Bugs Fixed

### Other Changes
Expand Down
5 changes: 3 additions & 2 deletions sdk/identity/azure-identity/azure/identity/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
DEFAULT_REFRESH_OFFSET = 300
DEFAULT_TOKEN_REFRESH_RETRY_DELAY = 30

CACHE_PRIMARY_SUFFIX = ""
CACHE_CAE_SUFFIX = ".cae"


class AzureAuthorityHosts:
AZURE_CHINA = "login.chinacloudapi.cn"
Expand Down Expand Up @@ -50,5 +53,3 @@ class EnvironmentVariables:

AZURE_FEDERATED_TOKEN_FILE = "AZURE_FEDERATED_TOKEN_FILE"
WORKLOAD_IDENTITY_VARS = (AZURE_AUTHORITY_HOST, AZURE_TENANT_ID, AZURE_FEDERATED_TOKEN_FILE)

AZURE_IDENTITY_DISABLE_CP1 = "AZURE_IDENTITY_DISABLE_CP1"
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure
:keyword bool enable_cae: enables configuring "CP1" client capabilities to support Continuous
Access Evaluation (CAE). Defaults to False.
:rtype: :class:`azure.core.credentials.AccessToken`
:raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user
information
Expand Down
51 changes: 36 additions & 15 deletions sdk/identity/azure-identity/azure/identity/_credentials/silent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
import platform
import time
from typing import Dict, Optional, Any
Expand All @@ -17,7 +16,7 @@
from .._internal.msal_client import MsalClient
from .._internal.shared_token_cache import NO_TOKEN
from .._persistent_cache import _load_persistent_cache, TokenCachePersistenceOptions
from .._constants import EnvironmentVariables
from .._constants import CACHE_CAE_SUFFIX, CACHE_PRIMARY_SUFFIX
from .. import AuthenticationRecord


Expand All @@ -37,8 +36,13 @@ def __init__(
self._tenant_id = tenant_id or self._auth_record.tenant_id
validate_tenant_id(self._tenant_id)
self._cache = kwargs.pop("_cache", None)
self._cae_cache = kwargs.pop("_cae_cache", None)

self._cache_persistence_options = kwargs.pop("cache_persistence_options", None)

self._client_applications: Dict[str, PublicClientApplication] = {}
self._cae_client_applications: Dict[str, PublicClientApplication] = {}

self._additionally_allowed_tenants = kwargs.pop("additionally_allowed_tenants", [])
self._client = MsalClient(**kwargs)
self._initialized = False
Expand All @@ -63,15 +67,23 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
return self._acquire_token_silent(*scopes, **kwargs)

def _initialize(self):
if not self._cache and platform.system() in {"Darwin", "Linux", "Windows"}:

# If no cache options were provided, the default cache will be used. This credential accepts the
# user's default cache regardless of whether it's encrypted. It doesn't create a new cache. If the
# default cache exists, the user must have created it earlier. If it's unencrypted, the user must
# have allowed that.
cache_options = self._cache_persistence_options or TokenCachePersistenceOptions(allow_unencrypted_storage=True)
is_platform_supported = platform.system() in {"Darwin", "Linux", "Windows"}

if not self._cache and is_platform_supported:
try:
self._cache = _load_persistent_cache(cache_options, cache_suffix=CACHE_PRIMARY_SUFFIX)
except Exception: # pylint:disable=broad-except
pass

if not self._cae_cache and is_platform_supported:
try:
# If no cache options were provided, the default cache will be used. This credential accepts the
# user's default cache regardless of whether it's encrypted. It doesn't create a new cache. If the
# default cache exists, the user must have created it earlier. If it's unencrypted, the user must
# have allowed that.
options = self._cache_persistence_options or \
TokenCachePersistenceOptions(allow_unencrypted_storage=True)
self._cache = _load_persistent_cache(options)
self._cae_cache = _load_persistent_cache(cache_options, cache_suffix=CACHE_CAE_SUFFIX)
except Exception: # pylint:disable=broad-except
pass

Expand All @@ -83,17 +95,26 @@ def _get_client_application(self, **kwargs: Any):
additionally_allowed_tenants=self._additionally_allowed_tenants,
**kwargs
)
if tenant_id not in self._client_applications:

client_applications_map = self._client_applications
capabilities = None
token_cache = self._cache

if kwargs.get("enable_cae"):
client_applications_map = self._cae_client_applications
# CP1 = can handle claims challenges (CAE)
capabilities = None if EnvironmentVariables.AZURE_IDENTITY_DISABLE_CP1 in os.environ else ["CP1"]
self._client_applications[tenant_id] = PublicClientApplication(
capabilities = ["CP1"]
token_cache = self._cae_cache

if tenant_id not in client_applications_map:
client_applications_map[tenant_id] = PublicClientApplication(
client_id=self._auth_record.client_id,
authority="https://{}/{}".format(self._auth_record.authority, tenant_id),
token_cache=self._cache,
token_cache=token_cache,
http_client=self._client,
client_capabilities=capabilities
)
return self._client_applications[tenant_id]
return client_applications_map[tenant_id]

@wrap_exceptions
def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessToken:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,4 @@ def _run_pipeline(self, request: HttpRequest, **kwargs: Any) -> AccessToken:
kwargs.pop("claims", None)
now = int(time.time())
response = self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)
return self._process_response(response, now, **kwargs)
Loading

0 comments on commit e31f36f

Please sign in to comment.