Skip to content

Commit

Permalink
Add OnBehalfOfCredential (#20451)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Sep 3, 2021
1 parent 609c4df commit 1f3fe27
Show file tree
Hide file tree
Showing 24 changed files with 1,963 additions and 68 deletions.
4 changes: 3 additions & 1 deletion sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
### Features Added
- `CertificateCredential` accepts certificates in PKCS12 format
([#13540](https://github.com/Azure/azure-sdk-for-python/issues/13540))
- `OnBehalfOfCredential` supports the on-behalf-of authentication flow for
accessing resources on behalf of users
([#19308](https://github.com/Azure/azure-sdk-for-python/issues/19308))

### Breaking Changes

Expand All @@ -17,7 +20,6 @@
([#18798](https://github.com/Azure/azure-sdk-for-python/issues/18798))



## 1.6.1 (2021-08-19)

### Other Changes
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/azure/identity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
EnvironmentCredential,
InteractiveBrowserCredential,
ManagedIdentityCredential,
OnBehalfOfCredential,
SharedTokenCacheCredential,
UsernamePasswordCredential,
VisualStudioCodeCredential,
Expand All @@ -45,6 +46,7 @@
"EnvironmentCredential",
"InteractiveBrowserCredential",
"KnownAuthorities",
"OnBehalfOfCredential",
"RegionalAuthority",
"ManagedIdentityCredential",
"SharedTokenCacheCredential",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .default import DefaultAzureCredential
from .environment import EnvironmentCredential
from .managed_identity import ManagedIdentityCredential
from .on_behalf_of import OnBehalfOfCredential
from .shared_cache import SharedTokenCacheCredential
from .azure_cli import AzureCliCredential
from .device_code import DeviceCodeCredential
Expand All @@ -32,6 +33,7 @@
"EnvironmentCredential",
"InteractiveBrowserCredential",
"ManagedIdentityCredential",
"OnBehalfOfCredential",
"SharedTokenCacheCredential",
"AzureCliCredential",
"UsernamePasswordCredential",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,16 @@ def load_pkcs12_certificate(certificate_data, password):
# type: (bytes, Optional[bytes]) -> _Cert
from cryptography.hazmat.primitives.serialization import Encoding, NoEncryption, pkcs12, PrivateFormat

private_key, cert, additional_certs = pkcs12.load_key_and_certificates(
certificate_data, password, backend=default_backend()
)
try:
private_key, cert, additional_certs = pkcs12.load_key_and_certificates(
certificate_data, password, backend=default_backend()
)
except ValueError as ex:
# mentioning PEM here because we raise this error when certificate_data is garbage
six.raise_from(ValueError("Failed to deserialize certificate in PEM or PKCS12 format"), ex)
if not private_key:
raise ValueError("The certificate must include its private key")
if not cert:
# mentioning PEM here because we raise this error when certificate_data is garbage
raise ValueError("Failed to deserialize certificate in PEM or PKCS12 format")

# This serializes the private key without any encryption it may have had. Doing so doesn't violate security
Expand Down Expand Up @@ -137,7 +140,7 @@ def get_client_credential(certificate_path, password=None, certificate_data=None
password = None # load_pkcs12_certificate returns cert.pem_bytes decrypted

if not isinstance(cert.private_key, RSAPrivateKey):
raise ValueError("CertificateCredential requires an RSA private key because it uses RS256 for signing")
raise ValueError("The certificate must have an RSA private key because RS256 is used for signing")

client_credential = {"private_key": cert.pem_bytes, "thumbprint": hexlify(cert.fingerprint).decode("utf-8")}
if password:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import time
from typing import cast, TYPE_CHECKING

import six

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError

from .certificate import get_client_credential
from .._internal.decorators import wrap_exceptions
from .._internal.get_token_mixin import GetTokenMixin
from .._internal.interactive import _build_auth_record
from .._internal.msal_credentials import MsalCredential

if TYPE_CHECKING:
from typing import Any, Dict, Optional, Union
import msal
from .. import AuthenticationRecord


class OnBehalfOfCredential(MsalCredential, GetTokenMixin):
"""Authenticates a service principal via the on-behalf-of flow.
This flow is typically used by middle-tier services that authorize requests to other services with a delegated
user identity. Because this is not an interactive authentication flow, an application using it must have admin
consent for any delegated permissions before requesting tokens for them. See `Azure Active Directory documentation
<https://docs.microsoft.com/azure/active-directory/develop/v2-oauth2-on-behalf-of-flow>`_ for a more detailed
description of the on-behalf-of flow.
:param str tenant_id: ID of the service principal's tenant. Also called its "directory" ID.
:param str client_id: the service principal's client ID
:param client_credential: a credential to authenticate the service principal, either one of its client secrets (a
string) or the bytes of a certificate in PEM or PKCS12 format including the private key
:type client_credential: str or bytes
:param str user_assertion: the access token the credential will use as the user assertion when requesting
on-behalf-of tokens
:keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant
the application is registered in. When False, which is the default, the credential will acquire tokens only
from the tenant specified by **tenant_id**.
:keyword str authority: Authority of an Azure Active Directory endpoint, for example "login.microsoftonline.com",
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts`
defines authorities for other clouds.
:keyword password: a certificate password. Used only when **client_credential** is certificate bytes. If this value
is a unicode string, it will be encoded as UTF-8. If the certificate requires a different encoding, pass
appropriately encoded bytes instead.
:paramtype password: str or bytes
"""

def __init__(self, tenant_id, client_id, client_credential, user_assertion, **kwargs):
# type: (str, str, Union[bytes, str], str, **Any) -> None
credential = cast("Union[Dict, str]", client_credential)
if isinstance(client_credential, six.binary_type):
try:
credential = get_client_credential(
certificate_path=None, password=kwargs.pop("password", None), certificate_data=client_credential
)
except ValueError as ex:
# client_credential isn't a valid cert. On 2.7 str == bytes and we ignore this exception because we
# can't tell whether the caller intended to provide a cert. On Python 3 we can say the caller provided
# either an invalid cert, or a client secret as bytes; both are errors.
if six.PY3:
message = (
'"client_credential" should be either a client secret (a string)'
+ " or the bytes of a certificate in PEM or PKCS12 format"
)
six.raise_from(ValueError(message), ex)

super(OnBehalfOfCredential, self).__init__(client_id, credential, tenant_id=tenant_id, **kwargs)
self._assertion = user_assertion
self._auth_record = None # type: Optional[AuthenticationRecord]

@wrap_exceptions
def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
if self._auth_record:
claims = kwargs.get("claims")
app = self._get_app(**kwargs)
for account in app.get_accounts(username=self._auth_record.username):
if account.get("home_account_id") != self._auth_record.home_account_id:
continue

now = int(time.time())
result = app.acquire_token_silent_with_error(list(scopes), account=account, claims_challenge=claims)
if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

return None

@wrap_exceptions
def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
app = self._get_app(**kwargs) # type: msal.ConfidentialClientApplication
request_time = int(time.time())
result = app.acquire_token_on_behalf_of(self._assertion, list(scopes), claims_challenge=kwargs.get("claims"))
if "access_token" not in result or "expires_in" not in result:
message = "Authentication failed: {}".format(result.get("error_description") or result.get("error"))
response = self._client.get_error_response(result)
raise ClientAuthenticationError(message=message, response=response)

try:
self._auth_record = _build_auth_record(result)
except ClientAuthenticationError:
pass # non-fatal; we'll use the assertion again next time instead of a refresh token

return AccessToken(result["access_token"], request_time + int(result["expires_in"]))
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Iterable, Optional
from typing import Any, Iterable, Optional, Union
from azure.core.credentials import AccessToken
from azure.core.pipeline import Pipeline
from .._internal import AadClientCertificate
Expand Down Expand Up @@ -65,6 +65,11 @@ def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

def obtain_token_on_behalf_of(self, scopes, client_credential, user_assertion, **kwargs):
# type: (Iterable[str], Union[str, AadClientCertificate], str, **Any) -> AccessToken
# no need for an implementation, non-async OnBehalfOfCredential acquires tokens through MSAL
raise NotImplementedError()

# pylint:disable=no-self-use
def _build_pipeline(self, **kwargs):
# type: (**Any) -> Pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from azure.core.exceptions import ClientAuthenticationError
from . import get_default_authority, normalize_authority
from .._internal import resolve_tenant
from .._internal.aadclient_certificate import AadClientCertificate

try:
from typing import TYPE_CHECKING
Expand All @@ -34,12 +35,13 @@
from azure.core.pipeline import AsyncPipeline, Pipeline, PipelineResponse
from azure.core.pipeline.policies import AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import AsyncHttpTransport, HttpTransport
from .._internal import AadClientCertificate

PipelineType = Union[AsyncPipeline, Pipeline]
PolicyType = Union[AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy]
TransportType = Union[AsyncHttpTransport, HttpTransport]

JWT_BEARER_ASSERTION = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"


class AadClientBase(ABC):
_POST = ["POST"]
Expand Down Expand Up @@ -96,6 +98,10 @@ def obtain_token_by_client_secret(self, scopes, secret, **kwargs):
def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
pass

@abc.abstractmethod
def obtain_token_on_behalf_of(self, scopes, client_credential, user_assertion, **kwargs):
pass

@abc.abstractmethod
def _build_pipeline(self, **kwargs):
pass
Expand Down Expand Up @@ -173,7 +179,7 @@ def _get_jwt_assertion_request(self, scopes, assertion, **kwargs):
# type: (Iterable[str], str, **Any) -> HttpRequest
data = {
"client_assertion": assertion,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion_type": JWT_BEARER_ASSERTION,
"client_id": self._client_id,
"grant_type": "client_credentials",
"scope": " ".join(scopes),
Expand All @@ -182,8 +188,8 @@ def _get_jwt_assertion_request(self, scopes, assertion, **kwargs):
request = self._post(data, **kwargs)
return request

def _get_client_certificate_request(self, scopes, certificate, **kwargs):
# type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest
def _get_client_certificate_assertion(self, certificate, **kwargs):
# type: (AadClientCertificate, **Any) -> str
now = int(time.time())
header = six.ensure_binary(
json.dumps({"typ": "JWT", "alg": "RS256", "x5t": certificate.thumbprint}), encoding="utf-8"
Expand All @@ -204,8 +210,11 @@ def _get_client_certificate_request(self, scopes, certificate, **kwargs):
jws = base64.urlsafe_b64encode(header) + b"." + base64.urlsafe_b64encode(payload)
signature = certificate.sign(jws)
jwt_bytes = jws + b"." + base64.urlsafe_b64encode(signature)
assertion = jwt_bytes.decode("utf-8")
return jwt_bytes.decode("utf-8")

def _get_client_certificate_request(self, scopes, certificate, **kwargs):
# type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest
assertion = self._get_client_certificate_assertion(certificate, **kwargs)
return self._get_jwt_assertion_request(scopes, assertion, **kwargs)

def _get_client_secret_request(self, scopes, secret, **kwargs):
Expand All @@ -219,6 +228,24 @@ def _get_client_secret_request(self, scopes, secret, **kwargs):
request = self._post(data, **kwargs)
return request

def _get_on_behalf_of_request(self, scopes, client_credential, user_assertion, **kwargs):
# type: (Iterable[str], Union[str, AadClientCertificate], str, **Any) -> HttpRequest
data = {
"assertion": user_assertion,
"client_id": self._client_id,
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"requested_token_use": "on_behalf_of",
"scope": " ".join(scopes),
}
if isinstance(client_credential, AadClientCertificate):
data["client_assertion"] = self._get_client_certificate_assertion(client_credential)
data["client_assertion_type"] = JWT_BEARER_ASSERTION
else:
data["client_secret"] = client_credential

request = self._post(data, **kwargs)
return request

def _get_refresh_token_request(self, scopes, refresh_token, **kwargs):
# type: (Iterable[str], str, **Any) -> HttpRequest
data = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, pem_bytes, password=None):
# type: (bytes, Optional[bytes]) -> None
private_key = serialization.load_pem_private_key(pem_bytes, password=password, backend=default_backend())
if not isinstance(private_key, RSAPrivateKey):
raise ValueError("CertificateCredential requires an RSA private key because it uses RS256 for signing")
raise ValueError("The certificate must have an RSA private key because RS256 is used for signing")
self._private_key = private_key

cert = x509.load_pem_x509_certificate(pem_bytes, default_backend())
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/azure/identity/aio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DefaultAzureCredential,
EnvironmentCredential,
ManagedIdentityCredential,
OnBehalfOfCredential,
SharedTokenCacheCredential,
VisualStudioCodeCredential,
)
Expand All @@ -30,6 +31,7 @@
"DefaultAzureCredential",
"EnvironmentCredential",
"ManagedIdentityCredential",
"OnBehalfOfCredential",
"ChainedTokenCredential",
"SharedTokenCacheCredential",
"VisualStudioCodeCredential",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .default import DefaultAzureCredential
from .environment import EnvironmentCredential
from .managed_identity import ManagedIdentityCredential
from .on_behalf_of import OnBehalfOfCredential
from .certificate import CertificateCredential
from .client_secret import ClientSecretCredential
from .shared_cache import SharedTokenCacheCredential
Expand All @@ -27,6 +28,7 @@
"DefaultAzureCredential",
"EnvironmentCredential",
"ManagedIdentityCredential",
"OnBehalfOfCredential",
"SharedTokenCacheCredential",
"VisualStudioCodeCredential",
]
Loading

0 comments on commit 1f3fe27

Please sign in to comment.