Skip to content

Commit

Permalink
Refactor ClientSecretCredential to use AadClient (#11718)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Jun 4, 2020
1 parent d4633cf commit 0ec1601
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from .._authn_client import AuthnClient
from .._base import ClientSecretCredentialBase
from .._internal import AadClient, ClientSecretCredentialBase

try:
from typing import TYPE_CHECKING
Expand All @@ -28,12 +27,7 @@ class ClientSecretCredential(ClientSecretCredentialBase):
defines authorities for other clouds.
"""

def __init__(self, tenant_id, client_id, client_secret, **kwargs):
# type: (str, str, str, **Any) -> None
super(ClientSecretCredential, self).__init__(tenant_id, client_id, client_secret, **kwargs)
self._client = AuthnClient(tenant=tenant_id, **kwargs)

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Request an access token for `scopes`.
Expand All @@ -48,8 +42,10 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
if not scopes:
raise ValueError("'get_token' requires at least one scope")

token = self._client.get_cached_token(scopes)
token = self._client.get_cached_access_token(scopes)
if not token:
data = dict(self._form_data, scope=" ".join(scopes))
token = self._client.request_token(scopes, form_data=data)
token = self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
return AadClient(tenant_id, client_id, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def get_default_authority():
from .auth_code_redirect_handler import AuthCodeRedirectServer
from .aadclient_certificate import AadClientCertificate
from .certificate_credential_base import CertificateCredentialBase
from .client_secret_credential_base import ClientSecretCredentialBase
from .exception_wrapper import wrap_exceptions
from .msal_credentials import ConfidentialClientCredential, InteractiveCredential, PublicClientCredential
from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse
Expand All @@ -60,6 +61,7 @@ def _scopes_to_resource(*scopes):
"AuthCodeRedirectServer",
"AadClientCertificate",
"CertificateCredentialBase",
"ClientSecretCredentialBase",
"ConfidentialClientCredential",
"get_default_authority",
"InteractiveCredential",
Expand Down
12 changes: 10 additions & 2 deletions sdk/identity/azure-identity/azure/identity/_internal/aad_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

class AadClient(AadClientBase):
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
# type: (str, str, Sequence[str], Optional[str], **Any) -> AccessToken
# type: (Sequence[str], str, str, Optional[str], **Any) -> AccessToken
request = self._get_auth_code_request(
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret
)
Expand All @@ -50,8 +50,16 @@ def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

def obtain_token_by_client_secret(self, scopes, secret, **kwargs):
# type: (Sequence[str], str, **Any) -> AccessToken
request = self._get_client_secret_request(scopes, secret)
now = int(time.time())
response = self._pipeline.run(request, stream=False, **kwargs)
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
# type: (str, Sequence[str], **Any) -> AccessToken
# type: (Sequence[str], str, **Any) -> AccessToken
request = self._get_refresh_token_request(scopes, refresh_token)
now = int(time.time())
response = self._pipeline.run(request, stream=False, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_
def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
pass

@abc.abstractmethod
def obtain_token_by_client_secret(self, scopes, secret, **kwargs):
pass

@abc.abstractmethod
def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
pass
Expand Down Expand Up @@ -131,6 +135,19 @@ def _get_client_certificate_request(self, scopes, certificate):
)
return request

def _get_client_secret_request(self, scopes, secret):
# type: (Sequence[str], str) -> HttpRequest
data = {
"client_id": self._client_id,
"client_secret": secret,
"grant_type": "client_credentials",
"scope": " ".join(scopes),
}
request = HttpRequest(
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data
)
return request

def _get_jwt_assertion(self, certificate):
# type: (AadClientCertificate) -> str
now = int(time.time())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,33 @@
# Licensed under the MIT License.
# ------------------------------------
import abc
from typing import TYPE_CHECKING

try:
ABC = abc.ABC
except AttributeError: # Python 2.7, abc exists, but not ABC
except AttributeError: # Python 2.7
ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Optional, Union

# pylint:disable=unused-import,ungrouped-imports
from typing import Any

class ClientSecretCredentialBase(object):
"""Sans I/O base for client secret credentials"""

def __init__(self, tenant_id, client_id, secret, **kwargs): # pylint:disable=unused-argument
class ClientSecretCredentialBase(ABC):
def __init__(self, tenant_id, client_id, client_secret, **kwargs):
# type: (str, str, str, **Any) -> None
if not client_id:
raise ValueError("client_id should be the id of an Azure Active Directory application")
if not secret:
if not client_secret:
raise ValueError("secret should be an Azure Active Directory application's client secret")
if not tenant_id:
raise ValueError(
"tenant_id should be an Azure Active Directory tenant's id (also called its 'directory id')"
)
self._form_data = {"client_id": client_id, "client_secret": secret, "grant_type": "client_credentials"}
super(ClientSecretCredentialBase, self).__init__()

self._client = self._get_auth_client(tenant_id, client_id, **kwargs)
self._secret = client_secret

@abc.abstractmethod
def _get_auth_client(self, tenant_id, client_id, **kwargs):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from typing import TYPE_CHECKING

from .base import AsyncCredentialBase
from .._authn_client import AsyncAuthnClient
from ..._base import ClientSecretCredentialBase
from .._internal import AadClient
from ..._internal import ClientSecretCredentialBase

if TYPE_CHECKING:
from typing import Any
from azure.core.credentials import AccessToken


class ClientSecretCredential(ClientSecretCredentialBase, AsyncCredentialBase):
class ClientSecretCredential(AsyncCredentialBase, ClientSecretCredentialBase):
"""Authenticates as a service principal using a client ID and client secret.
:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
Expand All @@ -25,10 +25,6 @@ class ClientSecretCredential(ClientSecretCredentialBase, AsyncCredentialBase):
defines authorities for other clouds.
"""

def __init__(self, tenant_id: str, client_id: str, client_secret: str, **kwargs: "Any") -> None:
super(ClientSecretCredential, self).__init__(tenant_id, client_id, client_secret, **kwargs)
self._client = AsyncAuthnClient(tenant=tenant_id, **kwargs)

async def __aenter__(self):
await self._client.__aenter__()
return self
Expand All @@ -38,7 +34,7 @@ async def close(self):

await self._client.__aexit__()

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
"""Asynchronously request an access token for `scopes`.
.. note:: This method is called by Azure SDK clients. It isn't intended for use in application code.
Expand All @@ -52,8 +48,10 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
if not scopes:
raise ValueError("'get_token' requires at least one scope")

token = self._client.get_cached_token(scopes)
token = self._client.get_cached_access_token(scopes)
if not token:
data = dict(self._form_data, scope=" ".join(scopes))
token = await self._client.request_token(scopes, form_data=data)
return token # type: ignore
token = await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
return AadClient(tenant_id, client_id, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ async def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

async def obtain_token_by_client_secret(
self, scopes: "Sequence[str]", secret: str, **kwargs: "Any"
) -> "AccessToken":
request = self._get_client_secret_request(scopes, secret)
now = int(time.time())
response = await self._pipeline.run(request, **kwargs)
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

async def obtain_token_by_refresh_token(
self, scopes: "Sequence[str]", refresh_token: str, **kwargs: "Any"
) -> "AccessToken":
Expand Down
24 changes: 24 additions & 0 deletions sdk/identity/azure-identity/tests/test_aad_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,30 @@ def send(request, **_):
assert transport.send.call_count == 1


def test_client_secret():
tenant_id = "tenant-id"
client_id = "client-id"
scope = "scope"
secret = "refresh-token"
access_token = "***"

def send(request, **_):
assert request.data["client_id"] == client_id
assert request.data["client_secret"] == secret
assert request.data["grant_type"] == "client_credentials"
assert request.data["scope"] == scope

return mock_response(json_payload={"access_token": access_token, "expires_in": 42})

transport = Mock(send=Mock(wraps=send))

client = AadClient(tenant_id, client_id, transport=transport)
token = client.obtain_token_by_client_secret(scopes=(scope,), secret=secret)

assert token.token == access_token
assert transport.send.call_count == 1


def test_refresh_token():
tenant_id = "tenant-id"
client_id = "client-id"
Expand Down
24 changes: 24 additions & 0 deletions sdk/identity/azure-identity/tests/test_aad_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,30 @@ async def send(request, **_):
assert transport.send.call_count == 1


async def test_client_secret():
tenant_id = "tenant-id"
client_id = "client-id"
scope = "scope"
secret = "refresh-token"
access_token = "***"

async def send(request, **_):
assert request.data["client_id"] == client_id
assert request.data["client_secret"] == secret
assert request.data["grant_type"] == "client_credentials"
assert request.data["scope"] == scope

return mock_response(json_payload={"access_token": access_token, "expires_in": 42})

transport = Mock(send=Mock(wraps=send))

client = AadClient(tenant_id, client_id, transport=transport)
token = await client.obtain_token_by_client_secret(scopes=(scope,), secret=secret)

assert token.token == access_token
assert transport.send.call_count == 1


async def test_refresh_token():
tenant_id = "tenant-id"
client_id = "client-id"
Expand Down

0 comments on commit 0ec1601

Please sign in to comment.