diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 997ff3df4e4d..d97d0fdd741b 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -4,11 +4,7 @@ ### Features Added -### Breaking Changes - -### Bugs Fixed - -### Other Changes +- `OnBehalfOfCredential` now supports client assertion callbacks through the `client_assertion_func` keyword argument. This enables authenticating with client assertions such as federated credentials. ([#35812](https://github.com/Azure/azure-sdk-for-python/pull/35812)) ## 1.17.0b1 (2024-05-13) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py b/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py index e55e77c7138c..28d368413d3b 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import time -from typing import Any, Optional +from typing import Any, Optional, Callable, Union, Dict import msal @@ -28,14 +28,18 @@ class OnBehalfOfCredential(MsalCredential, GetTokenMixin): description of the on-behalf-of flow. :param str tenant_id: ID of the service principal's tenant. Also called its "directory" ID. - :param str client_id: The service principal's client ID + :param str client_id: The service principal's client ID. :keyword str client_secret: Optional. A client secret to authenticate the service principal. - Either **client_secret** or **client_certificate** must be provided. + One of **client_secret**, **client_certificate**, or **client_assertion_func** must be provided. :keyword bytes client_certificate: Optional. The bytes of a certificate in PEM or PKCS12 format including - the private key to authenticate the service principal. Either **client_secret** or **client_certificate** must - be provided. + the private key to authenticate the service principal. One of **client_secret**, **client_certificate**, + or **client_assertion_func** must be provided. + :keyword client_assertion_func: Optional. Function that returns client assertions that authenticate the + application to Microsoft Entra ID. This function is called each time the credential requests a token. It must + return a valid assertion for the target resource. + :paramtype client_assertion_func: Callable[[], str] :keyword str user_assertion: Required. The access token the credential will use as the user assertion when - requesting on-behalf-of tokens + requesting on-behalf-of tokens. :keyword str authority: Authority of a Microsoft Entra endpoint, for example "login.microsoftonline.com", the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` @@ -65,14 +69,31 @@ class OnBehalfOfCredential(MsalCredential, GetTokenMixin): :caption: Create an OnBehalfOfCredential. """ - def __init__(self, tenant_id: str, client_id: str, **kwargs: Any) -> None: - self._assertion = kwargs.pop("user_assertion", None) + def __init__( + self, + tenant_id: str, + client_id: str, + *, + client_certificate: Optional[bytes] = None, + client_secret: Optional[str] = None, + client_assertion_func: Optional[Callable[[], str]] = None, + user_assertion: str, + **kwargs: Any + ) -> None: + self._assertion = user_assertion if not self._assertion: - raise TypeError('"user_assertion" is required.') - client_certificate = kwargs.pop("client_certificate", None) - client_secret = kwargs.pop("client_secret", None) + raise TypeError('"user_assertion" must not be empty.') - if client_certificate: + if client_assertion_func: + if client_certificate or client_secret: + raise ValueError( + "It is invalid to specify more than one of the following: " + '"client_assertion_func", "client_certificate" or "client_secret".' + ) + credential: Union[str, Dict[str, Any]] = { + "client_assertion": client_assertion_func, + } + elif client_certificate: if client_secret: raise ValueError('Specifying both "client_certificate" and "client_secret" is not valid.') try: @@ -86,7 +107,7 @@ def __init__(self, tenant_id: str, client_id: str, **kwargs: Any) -> None: elif client_secret: credential = client_secret else: - raise TypeError('Either "client_certificate" or "client_secret" must be provided') + raise TypeError('Either "client_certificate", "client_secret", or "client_assertion_func" must be provided') super(OnBehalfOfCredential, self).__init__( client_id=client_id, client_credential=credential, tenant_id=tenant_id, **kwargs diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index 5c8749ab696c..7fc6655cc7f2 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -267,7 +267,7 @@ def _get_client_secret_request(self, scopes: Iterable[str], secret: str, **kwarg def _get_on_behalf_of_request( self, scopes: Iterable[str], - client_credential: Union[str, AadClientCertificate], + client_credential: Union[str, AadClientCertificate, Dict[str, Any]], user_assertion: str, **kwargs: Any ) -> HttpRequest: @@ -288,6 +288,10 @@ def _get_on_behalf_of_request( if isinstance(client_credential, AadClientCertificate): data["client_assertion"] = self._get_client_certificate_assertion(client_credential) data["client_assertion_type"] = JWT_BEARER_ASSERTION + elif isinstance(client_credential, dict): + func = client_credential["client_assertion"] + data["client_assertion"] = func() + data["client_assertion_type"] = JWT_BEARER_ASSERTION else: data["client_secret"] = client_credential @@ -318,7 +322,7 @@ def _get_refresh_token_request(self, scopes: Iterable[str], refresh_token: str, def _get_refresh_token_on_behalf_of_request( self, scopes: Iterable[str], - client_credential: Union[str, AadClientCertificate], + client_credential: Union[str, AadClientCertificate, Dict[str, Any]], refresh_token: str, **kwargs: Any ) -> HttpRequest: @@ -338,6 +342,10 @@ def _get_refresh_token_on_behalf_of_request( if isinstance(client_credential, AadClientCertificate): data["client_assertion"] = self._get_client_certificate_assertion(client_credential) data["client_assertion_type"] = JWT_BEARER_ASSERTION + elif isinstance(client_credential, dict): + func = client_credential["client_assertion"] + data["client_assertion"] = func() + data["client_assertion_type"] = JWT_BEARER_ASSERTION else: data["client_secret"] = client_credential request = self._post(data, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index a44d7e6012fc..13587de59645 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -25,7 +25,7 @@ class MsalCredential: # pylint: disable=too-many-instance-attributes def __init__( self, client_id: str, - client_credential: Optional[Union[str, Dict[str, str]]] = None, + client_credential: Optional[Union[str, Dict[str, Any]]] = None, *, additionally_allowed_tenants: Optional[List[str]] = None, authority: Optional[str] = None, diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py index 4fbe8fb81b05..55dd007458db 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import logging -from typing import Optional, Union, Any +from typing import Optional, Union, Any, Dict, Callable from azure.core.exceptions import ClientAuthenticationError from azure.core.credentials import AccessToken @@ -25,14 +25,18 @@ class OnBehalfOfCredential(AsyncContextManager, GetTokenMixin): description of the on-behalf-of flow. :param str tenant_id: ID of the service principal's tenant. Also called its "directory" ID. - :param str client_id: The service principal's client ID + :param str client_id: The service principal's client ID. :keyword str client_secret: Optional. A client secret to authenticate the service principal. - Either **client_secret** or **client_certificate** must be provided. + One of **client_secret**, **client_certificate**, or **client_assertion_func** must be provided. :keyword bytes client_certificate: Optional. The bytes of a certificate in PEM or PKCS12 format including - the private key to authenticate the service principal. Either **client_secret** or **client_certificate** must - be provided. + the private key to authenticate the service principal. One of **client_secret**, **client_certificate**, + or **client_assertion_func** must be provided. + :keyword client_assertion_func: Optional. Function that returns client assertions that authenticate the + application to Microsoft Entra ID. This function is called each time the credential requests a token. It must + return a valid assertion for the target resource. + :paramtype client_assertion_func: Callable[[], str] :keyword str user_assertion: Required. The access token the credential will use as the user assertion when - requesting on-behalf-of tokens + requesting on-behalf-of tokens. :keyword str authority: Authority of a Microsoft Entra endpoint, for example "login.microsoftonline.com", the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` @@ -62,6 +66,7 @@ def __init__( *, client_certificate: Optional[bytes] = None, client_secret: Optional[str] = None, + client_assertion_func: Optional[Callable[[], str]] = None, user_assertion: str, **kwargs: Any ) -> None: @@ -69,8 +74,19 @@ def __init__( validate_tenant_id(tenant_id) self._assertion = user_assertion - - if client_certificate: + if not self._assertion: + raise TypeError('"user_assertion" must not be empty.') + + if client_assertion_func: + if client_certificate or client_secret: + raise ValueError( + "It is invalid to specify more than one of the following: " + '"client_assertion_func", "client_certificate" or "client_secret".' + ) + self._client_credential: Union[str, AadClientCertificate, Dict[str, Any]] = { + "client_assertion": client_assertion_func, + } + elif client_certificate: if client_secret: raise ValueError('Specifying both "client_certificate" and "client_secret" is not valid.') try: @@ -78,13 +94,11 @@ def __init__( except ValueError as ex: message = '"client_certificate" is not a valid certificate in PEM or PKCS12 format' raise ValueError(message) from ex - self._client_credential: Union[str, AadClientCertificate] = AadClientCertificate( - cert["private_key"], password=cert.get("passphrase") - ) + self._client_credential = AadClientCertificate(cert["private_key"], password=cert.get("passphrase")) elif client_secret: self._client_credential = client_secret else: - raise TypeError('Either "client_certificate" or "client_secret" must be provided') + raise TypeError('Either "client_certificate", "client_secret", or "client_assertion_func" must be provided') # note AadClient handles "authority" and any pipeline kwargs self._client = AadClient(tenant_id, client_id, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py index f556bc2cee8e..2f542747aab4 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import time -from typing import Iterable, Optional, Union +from typing import Iterable, Optional, Union, Dict, Any from azure.core.credentials import AccessToken from azure.core.pipeline import AsyncPipeline @@ -57,7 +57,11 @@ async def obtain_token_by_refresh_token(self, scopes: Iterable[str], refresh_tok return await self._run_pipeline(request, **kwargs) async def obtain_token_by_refresh_token_on_behalf_of( # pylint: disable=name-too-long - self, scopes: Iterable[str], client_credential: Union[str, AadClientCertificate], refresh_token: str, **kwargs + self, + scopes: Iterable[str], + client_credential: Union[str, AadClientCertificate, Dict[str, Any]], + refresh_token: str, + **kwargs ) -> AccessToken: request = self._get_refresh_token_on_behalf_of_request( scopes, client_credential=client_credential, refresh_token=refresh_token, **kwargs @@ -65,7 +69,11 @@ async def obtain_token_by_refresh_token_on_behalf_of( # pylint: disable=name-to return await self._run_pipeline(request, **kwargs) async def obtain_token_on_behalf_of( - self, scopes: Iterable[str], client_credential: Union[str, AadClientCertificate], user_assertion: str, **kwargs + self, + scopes: Iterable[str], + client_credential: Union[str, AadClientCertificate, Dict[str, Any]], + user_assertion: str, + **kwargs ) -> AccessToken: request = self._get_on_behalf_of_request( scopes=scopes, client_credential=client_credential, user_assertion=user_assertion, **kwargs diff --git a/sdk/identity/azure-identity/samples/on_behalf_of_client_assertion.py b/sdk/identity/azure-identity/samples/on_behalf_of_client_assertion.py new file mode 100644 index 000000000000..03fe4c3b1d4a --- /dev/null +++ b/sdk/identity/azure-identity/samples/on_behalf_of_client_assertion.py @@ -0,0 +1,43 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +""" +FILE: on_behalf_of_client_assertion.py +DESCRIPTION: + This sample demonstrates the use of OnBehalfOfCredential to authenticate the Key Vault SecretClient using a managed + identity as the client assertion. More information about the On-Behalf-Of flow can be found here: + https://learn.microsoft.com/entra/identity-platform/v2-oauth2-on-behalf-of-flow. +USAGE: + python on_behalf_of_client_assertion.py + +**Note** - This sample requires the `azure-keyvault-secrets` package. +""" +# [START obo_client_assertion] +from azure.identity import OnBehalfOfCredential, ManagedIdentityCredential +from azure.keyvault.secrets import SecretClient + + +# Replace the following variables with your own values. +tenant_id = "" +client_id = "" +user_assertion = "" + +managed_identity_credential = ManagedIdentityCredential() + + +def get_managed_identity_token() -> str: + # This function should return an access token obtained from a managed identity. + access_token = managed_identity_credential.get_token("api://AzureADTokenExchange") + return access_token.token + + +credential = OnBehalfOfCredential( + tenant_id=tenant_id, + client_id=client_id, + user_assertion=user_assertion, + client_assertion_func=get_managed_identity_token, +) + +client = SecretClient(vault_url="https://.vault.azure.net/", credential=credential) +# [END obo_client_assertion] diff --git a/sdk/identity/azure-identity/tests/test_obo.py b/sdk/identity/azure-identity/tests/test_obo.py index 4761dc40a50a..ea369310877b 100644 --- a/sdk/identity/azure-identity/tests/test_obo.py +++ b/sdk/identity/azure-identity/tests/test_obo.py @@ -12,6 +12,7 @@ from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import OnBehalfOfCredential, UsernamePasswordCredential from azure.identity._constants import EnvironmentVariables +from azure.identity._internal.aad_client_base import JWT_BEARER_ASSERTION from azure.identity._internal.user_agent import USER_AGENT import pytest from urllib.parse import urlparse @@ -228,3 +229,53 @@ def test_no_client_credential(): """The credential should raise ValueError when ctoring with no client_secret or client_certificate""" with pytest.raises(TypeError): credential = OnBehalfOfCredential("tenant-id", "client-id", user_assertion="assertion") + + +def test_client_assertion_func(): + """The credential should accept a client_assertion_func""" + expected_client_assertion = "client-assertion" + expected_user_assertion = "user-assertion" + expected_token = "***" + func_call_count = 0 + + def client_assertion_func(): + nonlocal func_call_count + func_call_count += 1 + return expected_client_assertion + + def send(request, **kwargs): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + if "/oauth2/v2.0/token" not in parsed.path: + return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant)) + + assert request.data.get("client_assertion") == expected_client_assertion + assert request.data.get("client_assertion_type") == JWT_BEARER_ASSERTION + assert request.data.get("assertion") == expected_user_assertion + return mock_response(json_payload=build_aad_response(access_token=expected_token)) + + transport = Mock(send=Mock(wraps=send)) + credential = OnBehalfOfCredential( + "tenant-id", + "client-id", + client_assertion_func=client_assertion_func, + user_assertion=expected_user_assertion, + transport=transport, + ) + + access_token = credential.get_token("scope") + assert access_token.token == expected_token + assert func_call_count == 1 + + +def test_client_assertion_func_with_client_certificate(): + """The credential should raise ValueError when ctoring with both client_assertion_func and client_certificate""" + with pytest.raises(ValueError) as ex: + credential = OnBehalfOfCredential( + "tenant-id", + "client-id", + client_assertion_func=lambda: "client-assertion", + client_certificate=b"certificate", + user_assertion="assertion", + ) + assert "It is invalid to specify more than one of the following" in str(ex.value) diff --git a/sdk/identity/azure-identity/tests/test_obo_async.py b/sdk/identity/azure-identity/tests/test_obo_async.py index 8c143d4b72e2..f9d80323e84f 100644 --- a/sdk/identity/azure-identity/tests/test_obo_async.py +++ b/sdk/identity/azure-identity/tests/test_obo_async.py @@ -12,6 +12,7 @@ from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import UsernamePasswordCredential from azure.identity._constants import EnvironmentVariables +from azure.identity._internal.aad_client_base import JWT_BEARER_ASSERTION from azure.identity._internal.user_agent import USER_AGENT from azure.identity.aio import OnBehalfOfCredential import pytest @@ -305,3 +306,54 @@ async def test_no_client_credential(): """The credential should raise ValueError when ctoring with no client_secret or client_certificate""" with pytest.raises(TypeError): credential = OnBehalfOfCredential("tenant-id", "client-id", user_assertion="assertion") + + +@pytest.mark.asyncio +async def test_client_assertion_func(): + """The credential should accept a client_assertion_func""" + expected_client_assertion = "client-assertion" + expected_user_assertion = "user-assertion" + expected_token = "***" + func_call_count = 0 + + def client_assertion_func(): + nonlocal func_call_count + func_call_count += 1 + return expected_client_assertion + + async def send(request, **kwargs): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + if "/oauth2/v2.0/token" not in parsed.path: + return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant)) + + assert request.data.get("client_assertion") == expected_client_assertion + assert request.data.get("client_assertion_type") == JWT_BEARER_ASSERTION + assert request.data.get("assertion") == expected_user_assertion + return mock_response(json_payload=build_aad_response(access_token=expected_token)) + + transport = Mock(send=send) + credential = OnBehalfOfCredential( + "tenant-id", + "client-id", + client_assertion_func=client_assertion_func, + user_assertion=expected_user_assertion, + transport=transport, + ) + token = await credential.get_token("scope") + assert token.token == expected_token + assert func_call_count == 1 + + +@pytest.mark.asyncio +async def test_client_assertion_func_with_client_certificate(): + """The credential should raise when given both client_assertion_func and client_certificate""" + with pytest.raises(ValueError) as ex: + OnBehalfOfCredential( + "tenant-id", + "client-id", + client_assertion_func=lambda: "client-assertion", + client_certificate=b"cert", + user_assertion="assertion", + ) + assert "It is invalid to specify more than one of the following" in str(ex.value)