Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Identity] Allow use of client assertion in OBO cred #35812

Merged
merged 3 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@

### Features Added

### Breaking Changes

### Bugs Fixed

### Other Changes
- `OnBehalfOfCredential` now supports client assertion callbacks through the `client_assertion_func` keyword argument. This enables authenticating with client assertions such as federated credentials. ([#35812](https://github.com/Azure/azure-sdk-for-python/pull/35812))

## 1.17.0b1 (2024-05-13)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import time
from typing import Any, Optional
from typing import Any, Optional, Callable, Union, Dict

import msal

Expand All @@ -28,14 +28,18 @@ class OnBehalfOfCredential(MsalCredential, GetTokenMixin):
description of the on-behalf-of flow.

:param str tenant_id: ID of the service principal's tenant. Also called its "directory" ID.
:param str client_id: The service principal's client ID
:param str client_id: The service principal's client ID.
:keyword str client_secret: Optional. A client secret to authenticate the service principal.
Either **client_secret** or **client_certificate** must be provided.
One of **client_secret**, **client_certificate**, or **client_assertion_func** must be provided.
:keyword bytes client_certificate: Optional. The bytes of a certificate in PEM or PKCS12 format including
the private key to authenticate the service principal. Either **client_secret** or **client_certificate** must
be provided.
the private key to authenticate the service principal. One of **client_secret**, **client_certificate**,
or **client_assertion_func** must be provided.
:keyword client_assertion_func: Optional. Function that returns client assertions that authenticate the
application to Microsoft Entra ID. This function is called each time the credential requests a token. It must
return a valid assertion for the target resource.
:paramtype client_assertion_func: Callable[[], str]
:keyword str user_assertion: Required. The access token the credential will use as the user assertion when
requesting on-behalf-of tokens
requesting on-behalf-of tokens.

:keyword str authority: Authority of a Microsoft Entra endpoint, for example "login.microsoftonline.com",
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts`
Expand Down Expand Up @@ -65,14 +69,31 @@ class OnBehalfOfCredential(MsalCredential, GetTokenMixin):
:caption: Create an OnBehalfOfCredential.
"""

def __init__(self, tenant_id: str, client_id: str, **kwargs: Any) -> None:
self._assertion = kwargs.pop("user_assertion", None)
def __init__(
self,
tenant_id: str,
client_id: str,
*,
client_certificate: Optional[bytes] = None,
client_secret: Optional[str] = None,
client_assertion_func: Optional[Callable[[], str]] = None,
user_assertion: str,
**kwargs: Any
) -> None:
self._assertion = user_assertion
if not self._assertion:
raise TypeError('"user_assertion" is required.')
client_certificate = kwargs.pop("client_certificate", None)
client_secret = kwargs.pop("client_secret", None)
raise TypeError('"user_assertion" must not be empty.')

if client_certificate:
if client_assertion_func:
if client_certificate or client_secret:
raise ValueError(
"It is invalid to specify more than one of the following: "
'"client_assertion_func", "client_certificate" or "client_secret".'
)
credential: Union[str, Dict[str, Any]] = {
"client_assertion": client_assertion_func,
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
}
elif client_certificate:
if client_secret:
raise ValueError('Specifying both "client_certificate" and "client_secret" is not valid.')
try:
Expand All @@ -86,7 +107,7 @@ def __init__(self, tenant_id: str, client_id: str, **kwargs: Any) -> None:
elif client_secret:
credential = client_secret
else:
raise TypeError('Either "client_certificate" or "client_secret" must be provided')
raise TypeError('Either "client_certificate", "client_secret", or "client_assertion_func" must be provided')

super(OnBehalfOfCredential, self).__init__(
client_id=client_id, client_credential=credential, tenant_id=tenant_id, **kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def _get_client_secret_request(self, scopes: Iterable[str], secret: str, **kwarg
def _get_on_behalf_of_request(
self,
scopes: Iterable[str],
client_credential: Union[str, AadClientCertificate],
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
user_assertion: str,
**kwargs: Any
) -> HttpRequest:
Expand All @@ -288,6 +288,10 @@ def _get_on_behalf_of_request(
if isinstance(client_credential, AadClientCertificate):
data["client_assertion"] = self._get_client_certificate_assertion(client_credential)
data["client_assertion_type"] = JWT_BEARER_ASSERTION
elif isinstance(client_credential, dict):
func = client_credential["client_assertion"]
data["client_assertion"] = func()
data["client_assertion_type"] = JWT_BEARER_ASSERTION
else:
data["client_secret"] = client_credential

Expand Down Expand Up @@ -318,7 +322,7 @@ def _get_refresh_token_request(self, scopes: Iterable[str], refresh_token: str,
def _get_refresh_token_on_behalf_of_request(
self,
scopes: Iterable[str],
client_credential: Union[str, AadClientCertificate],
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
refresh_token: str,
**kwargs: Any
) -> HttpRequest:
Expand All @@ -338,6 +342,10 @@ def _get_refresh_token_on_behalf_of_request(
if isinstance(client_credential, AadClientCertificate):
data["client_assertion"] = self._get_client_certificate_assertion(client_credential)
data["client_assertion_type"] = JWT_BEARER_ASSERTION
elif isinstance(client_credential, dict):
func = client_credential["client_assertion"]
data["client_assertion"] = func()
data["client_assertion_type"] = JWT_BEARER_ASSERTION
else:
data["client_secret"] = client_credential
request = self._post(data, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class MsalCredential: # pylint: disable=too-many-instance-attributes
def __init__(
self,
client_id: str,
client_credential: Optional[Union[str, Dict[str, str]]] = None,
client_credential: Optional[Union[str, Dict[str, Any]]] = None,
*,
additionally_allowed_tenants: Optional[List[str]] = None,
authority: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import logging
from typing import Optional, Union, Any
from typing import Optional, Union, Any, Dict, Callable

from azure.core.exceptions import ClientAuthenticationError
from azure.core.credentials import AccessToken
Expand All @@ -25,14 +25,18 @@ class OnBehalfOfCredential(AsyncContextManager, GetTokenMixin):
description of the on-behalf-of flow.

:param str tenant_id: ID of the service principal's tenant. Also called its "directory" ID.
:param str client_id: The service principal's client ID
:param str client_id: The service principal's client ID.
:keyword str client_secret: Optional. A client secret to authenticate the service principal.
Either **client_secret** or **client_certificate** must be provided.
One of **client_secret**, **client_certificate**, or **client_assertion_func** must be provided.
:keyword bytes client_certificate: Optional. The bytes of a certificate in PEM or PKCS12 format including
the private key to authenticate the service principal. Either **client_secret** or **client_certificate** must
be provided.
the private key to authenticate the service principal. One of **client_secret**, **client_certificate**,
or **client_assertion_func** must be provided.
:keyword client_assertion_func: Optional. Function that returns client assertions that authenticate the
application to Microsoft Entra ID. This function is called each time the credential requests a token. It must
return a valid assertion for the target resource.
:paramtype client_assertion_func: Callable[[], str]
:keyword str user_assertion: Required. The access token the credential will use as the user assertion when
requesting on-behalf-of tokens
requesting on-behalf-of tokens.

:keyword str authority: Authority of a Microsoft Entra endpoint, for example "login.microsoftonline.com",
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts`
Expand Down Expand Up @@ -62,29 +66,39 @@ def __init__(
*,
client_certificate: Optional[bytes] = None,
client_secret: Optional[str] = None,
client_assertion_func: Optional[Callable[[], str]] = None,
user_assertion: str,
**kwargs: Any
) -> None:
super().__init__()
validate_tenant_id(tenant_id)

self._assertion = user_assertion

if client_certificate:
if not self._assertion:
raise TypeError('"user_assertion" must not be empty.')

if client_assertion_func:
if client_certificate or client_secret:
raise ValueError(
"It is invalid to specify more than one of the following: "
'"client_assertion_func", "client_certificate" or "client_secret".'
)
self._client_credential: Union[str, AadClientCertificate, Dict[str, Any]] = {
"client_assertion": client_assertion_func,
}
elif client_certificate:
if client_secret:
raise ValueError('Specifying both "client_certificate" and "client_secret" is not valid.')
try:
cert = get_client_credential(None, kwargs.pop("password", None), client_certificate)
except ValueError as ex:
message = '"client_certificate" is not a valid certificate in PEM or PKCS12 format'
raise ValueError(message) from ex
self._client_credential: Union[str, AadClientCertificate] = AadClientCertificate(
cert["private_key"], password=cert.get("passphrase")
)
self._client_credential = AadClientCertificate(cert["private_key"], password=cert.get("passphrase"))
elif client_secret:
self._client_credential = client_secret
else:
raise TypeError('Either "client_certificate" or "client_secret" must be provided')
raise TypeError('Either "client_certificate", "client_secret", or "client_assertion_func" must be provided')

# note AadClient handles "authority" and any pipeline kwargs
self._client = AadClient(tenant_id, client_id, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import time
from typing import Iterable, Optional, Union
from typing import Iterable, Optional, Union, Dict, Any

from azure.core.credentials import AccessToken
from azure.core.pipeline import AsyncPipeline
Expand Down Expand Up @@ -57,15 +57,23 @@ async def obtain_token_by_refresh_token(self, scopes: Iterable[str], refresh_tok
return await self._run_pipeline(request, **kwargs)

async def obtain_token_by_refresh_token_on_behalf_of( # pylint: disable=name-too-long
self, scopes: Iterable[str], client_credential: Union[str, AadClientCertificate], refresh_token: str, **kwargs
self,
scopes: Iterable[str],
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
refresh_token: str,
**kwargs
) -> AccessToken:
request = self._get_refresh_token_on_behalf_of_request(
scopes, client_credential=client_credential, refresh_token=refresh_token, **kwargs
)
return await self._run_pipeline(request, **kwargs)

async def obtain_token_on_behalf_of(
self, scopes: Iterable[str], client_credential: Union[str, AadClientCertificate], user_assertion: str, **kwargs
self,
scopes: Iterable[str],
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
user_assertion: str,
**kwargs
) -> AccessToken:
request = self._get_on_behalf_of_request(
scopes=scopes, client_credential=client_credential, user_assertion=user_assertion, **kwargs
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""
FILE: on_behalf_of_client_assertion.py
DESCRIPTION:
This sample demonstrates the use of OnBehalfOfCredential to authenticate the Key Vault SecretClient using a managed
identity as the client assertion. More information about the On-Behalf-Of flow can be found here:
https://learn.microsoft.com/entra/identity-platform/v2-oauth2-on-behalf-of-flow.
USAGE:
python on_behalf_of_client_assertion.py

**Note** - This sample requires the `azure-keyvault-secrets` package.
"""
# [START obo_client_assertion]
from azure.identity import OnBehalfOfCredential, ManagedIdentityCredential
from azure.keyvault.secrets import SecretClient


# Replace the following variables with your own values.
tenant_id = "<tenant_id>"
client_id = "<client_id>"
user_assertion = "<user_assertion>"

managed_identity_credential = ManagedIdentityCredential()


def get_managed_identity_token() -> str:
# This function should return an access token obtained from a managed identity.
access_token = managed_identity_credential.get_token("api://AzureADTokenExchange")
return access_token.token


credential = OnBehalfOfCredential(
tenant_id=tenant_id,
client_id=client_id,
user_assertion=user_assertion,
client_assertion_func=get_managed_identity_token,
)

client = SecretClient(vault_url="https://<your-key-vault-name>.vault.azure.net/", credential=credential)
# [END obo_client_assertion]
51 changes: 51 additions & 0 deletions sdk/identity/azure-identity/tests/test_obo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy
from azure.identity import OnBehalfOfCredential, UsernamePasswordCredential
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.aad_client_base import JWT_BEARER_ASSERTION
from azure.identity._internal.user_agent import USER_AGENT
import pytest
from urllib.parse import urlparse
Expand Down Expand Up @@ -228,3 +229,53 @@ def test_no_client_credential():
"""The credential should raise ValueError when ctoring with no client_secret or client_certificate"""
with pytest.raises(TypeError):
credential = OnBehalfOfCredential("tenant-id", "client-id", user_assertion="assertion")


def test_client_assertion_func():
"""The credential should accept a client_assertion_func"""
expected_client_assertion = "client-assertion"
expected_user_assertion = "user-assertion"
expected_token = "***"
func_call_count = 0

def client_assertion_func():
nonlocal func_call_count
func_call_count += 1
return expected_client_assertion

def send(request, **kwargs):
parsed = urlparse(request.url)
tenant = parsed.path.split("/")[1]
if "/oauth2/v2.0/token" not in parsed.path:
return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant))

assert request.data.get("client_assertion") == expected_client_assertion
assert request.data.get("client_assertion_type") == JWT_BEARER_ASSERTION
assert request.data.get("assertion") == expected_user_assertion
return mock_response(json_payload=build_aad_response(access_token=expected_token))

transport = Mock(send=Mock(wraps=send))
credential = OnBehalfOfCredential(
"tenant-id",
"client-id",
client_assertion_func=client_assertion_func,
user_assertion=expected_user_assertion,
transport=transport,
)

access_token = credential.get_token("scope")
assert access_token.token == expected_token
assert func_call_count == 1


def test_client_assertion_func_with_client_certificate():
"""The credential should raise ValueError when ctoring with both client_assertion_func and client_certificate"""
with pytest.raises(ValueError) as ex:
credential = OnBehalfOfCredential(
"tenant-id",
"client-id",
client_assertion_func=lambda: "client-assertion",
client_certificate=b"certificate",
user_assertion="assertion",
)
assert "It is invalid to specify more than one of the following" in str(ex.value)
Loading