Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use MSAL's custom transport API #11892

Merged
merged 7 commits into from
Jun 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Release History

## 1.4.0b6 (Unreleased)
- Upgraded minimum `msal` version to 1.3.0
- The async `AzureCliCredential` correctly invokes `/bin/sh`
([#12048](https://github.com/Azure/azure-sdk-for-python/issues/12048))

Expand All @@ -18,14 +19,14 @@
identity by its client ID, continue using the `client_id` argument. To
specify an identity by any other ID, use the `identity_config` argument,
for example: `ManagedIdentityCredential(identity_config={"object_id": ".."})`
([#10989](https://github.com/Azure/azure-sdk-for-python/issues/10989))
([#10989](https://github.com/Azure/azure-sdk-for-python/issues/10989))
- `CertificateCredential` and `ClientSecretCredential` can optionally store
access tokens they acquire in a persistent cache. To enable this, construct
the credential with `enable_persistent_cache=True`. On Linux, the persistent
cache requires libsecret and `pygobject`. If these are unavailable or
unusable (e.g. in an SSH session), loading the persistent cache will raise an
error. You may optionally configure the credential to fall back to an
unencrypted cache by constructing it with keyword argument
unencrypted cache by constructing it with keyword argument
`allow_unencrypted_cache=True`.
([#11347](https://github.com/Azure/azure-sdk-for-python/issues/11347))
- `AzureCliCredential` raises `CredentialUnavailableError` when no user is
Expand Down Expand Up @@ -66,7 +67,7 @@

## 1.4.0b3 (2020-05-04)
- `EnvironmentCredential` correctly initializes `UsernamePasswordCredential`
with the value of `AZURE_TENANT_ID`
with the value of `AZURE_TENANT_ID`
([#11127](https://github.com/Azure/azure-sdk-for-python/pull/11127))
- Values for the constructor keyword argument `authority` and
`AZURE_AUTHORITY_HOST` may optionally specify an "https" scheme. For example,
Expand All @@ -86,7 +87,7 @@ with the value of `AZURE_TENANT_ID`
- `enable_persistent_cache=True` configures these credentials to use a
persistent cache on supported platforms (in this release, Windows only).
By default they cache in memory only.
- Now `DefaultAzureCredential` can authenticate with the identity signed in to
- Now `DefaultAzureCredential` can authenticate with the identity signed in to
Visual Studio Code's Azure extension.
([#10472](https://github.com/Azure/azure-sdk-for-python/issues/10472))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def __init__(self, client_id, username, password, **kwargs):
def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> dict
app = self._get_app()
with self._adapter:
return app.acquire_token_by_username_password(
username=self._username, password=self._password, scopes=list(scopes)
)
return app.acquire_token_by_username_password(
username=self._username, password=self._password, scopes=list(scopes)
)
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def get_default_authority():
from .certificate_credential_base import CertificateCredentialBase
from .client_secret_credential_base import ClientSecretCredentialBase
from .exception_wrapper import wrap_exceptions
from .msal_credentials import ConfidentialClientCredential, InteractiveCredential, PublicClientCredential
from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse
from .msal_credentials import InteractiveCredential, PublicClientCredential


def _scopes_to_resource(*scopes):
Expand All @@ -62,11 +61,8 @@ def _scopes_to_resource(*scopes):
"AadClientCertificate",
"CertificateCredentialBase",
"ClientSecretCredentialBase",
"ConfidentialClientCredential",
"get_default_authority",
"InteractiveCredential",
"MsalTransportAdapter",
"MsalTransportResponse",
"normalize_authority",
"PublicClientCredential",
"wrap_exceptions",
Expand Down
137 changes: 137 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_internal/msal_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import six

from azure.core.configuration import Configuration
from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.policies import (
ContentDecodePolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
NetworkTraceLoggingPolicy,
ProxyPolicy,
RetryPolicy,
UserAgentPolicy,
)
from azure.core.pipeline.transport import HttpRequest, RequestsTransport

from .user_agent import USER_AGENT

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Dict, List, Optional, Union
from azure.core.pipeline import PipelineResponse
from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import HttpTransport

PolicyList = List[Union[HTTPPolicy, SansIOHTTPPolicy]]
RequestData = Union[Dict[str, str], str]


class MsalResponse(object):
"""Wraps HttpResponse according to msal.oauth2cli.http"""

def __init__(self, response):
# type: (PipelineResponse) -> None
self._response = response

@property
def status_code(self):
# type: () -> int
return self._response.http_response.status_code

@property
def text(self):
# type: () -> str
return self._response.http_response.text(encoding="utf-8")

def raise_for_status(self):
if self.status_code < 400:
return

if ContentDecodePolicy.CONTEXT_NAME in self._response.context:
content = self._response.context[ContentDecodePolicy.CONTEXT_NAME]
if "error" in content or "error_description" in content:
message = "Authentication failed: {}".format(content.get("error_description") or content.get("error"))
else:
for secret in ("access_token", "refresh_token"):
if secret in content:
content[secret] = "***"
message = 'Unexpected response from Azure Active Directory: "{}"'.format(content)
else:
message = "Unexpected response from Azure Active Directory"

raise ClientAuthenticationError(message=message, response=self._response.http_response)


class MsalClient(object):
"""Wraps Pipeline according to msal.oauth2cli.http"""

def __init__(self, **kwargs): # pylint:disable=missing-client-constructor-parameter-credential
# type: (**Any) -> None
self._pipeline = _build_pipeline(**kwargs)

def post(self, url, params=None, data=None, headers=None, **kwargs): # pylint:disable=unused-argument
# type: (str, Optional[Dict[str, str]], RequestData, Optional[Dict[str, str]], **Any) -> MsalResponse
request = HttpRequest("POST", url, headers=headers)
if params:
request.format_parameters(params)
if data:
if isinstance(data, dict):
request.headers["Content-Type"] = "application/x-www-form-urlencoded"
request.set_formdata_body(data)
elif isinstance(data, six.text_type):
body_bytes = six.ensure_binary(data)
request.set_bytes_body(body_bytes)
else:
raise ValueError('expected "data" to be text or a dict')

response = self._pipeline.run(request)
return MsalResponse(response)

def get(self, url, params=None, headers=None, **kwargs): # pylint:disable=unused-argument
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], **Any) -> MsalResponse
request = HttpRequest("GET", url, headers=headers)
if params:
request.format_parameters(params)
response = self._pipeline.run(request)
return MsalResponse(response)


def _create_config(**kwargs):
# type: (Any) -> Configuration
config = Configuration(**kwargs)
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
config.retry_policy = RetryPolicy(**kwargs)
config.proxy_policy = ProxyPolicy(**kwargs)
config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs)
return config


def _build_pipeline(config=None, policies=None, transport=None, **kwargs):
# type: (Optional[Configuration], Optional[PolicyList], Optional[HttpTransport], **Any) -> Pipeline
config = config or _create_config(**kwargs)

if policies is None: # [] is a valid policy list
policies = [
ContentDecodePolicy(),
config.user_agent_policy,
config.proxy_policy,
config.retry_policy,
config.logging_policy,
DistributedTracingPolicy(**kwargs),
HttpLoggingPolicy(**kwargs),
]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need header policy?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. User agent is the only header we want to set on every request.


if not transport:
transport = RequestsTransport(**kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to let user to customize transport?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can user achieve it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pass it as a parameter to the credential. Our tests do this:

credential = InteractiveBrowserCredential(
_cache=TokenCache(),
authority=environment,
client_id=client_id,
server_class=server_class,
tenant_id=tenant_id,
transport=transport,
)


return Pipeline(transport=transport, policies=policies)
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""Credentials wrapping MSAL applications and delegating token acquisition and caching to them.
This entails monkeypatching MSAL's OAuth client with an adapter substituting an azure-core pipeline for Requests.
"""
import abc
import base64
import json
Expand All @@ -17,7 +14,7 @@
from azure.core.exceptions import ClientAuthenticationError

from .exception_wrapper import wrap_exceptions
from .msal_transport_adapter import MsalTransportAdapter
from .msal_client import MsalClient
from .persistent_cache import load_user_cache
from .._constants import KnownAuthorities
from .._exceptions import AuthenticationRequiredError, CredentialUnavailableError
Expand Down Expand Up @@ -102,7 +99,7 @@ def __init__(self, client_id, client_credential=None, **kwargs):
else:
self._cache = msal.TokenCache()

self._adapter = kwargs.pop("msal_adapter", None) or MsalTransportAdapter(**kwargs)
self._client = MsalClient(**kwargs)

# postpone creating the wrapped application because its initializer uses the network
self._msal_app = None # type: Optional[msal.ClientApplication]
Expand All @@ -119,53 +116,17 @@ def _get_app(self):

def _create_app(self, cls):
# type: (Type[msal.ClientApplication]) -> msal.ClientApplication
"""Creates an MSAL application, patching msal.authority to use an azure-core pipeline during tenant discovery"""

# MSAL application initializers use msal.authority to send AAD tenant discovery requests
with self._adapter:
# MSAL's "authority" is a URL e.g. https://login.microsoftonline.com/common
app = cls(
client_id=self._client_id,
client_credential=self._client_credential,
authority="{}/{}".format(self._authority, self._tenant_id),
token_cache=self._cache,
)

# monkeypatch the app to replace requests.Session with MsalTransportAdapter
app.client.session.close()
app.client.session = self._adapter
app = cls(
client_id=self._client_id,
client_credential=self._client_credential,
authority="{}/{}".format(self._authority, self._tenant_id),
token_cache=self._cache,
http_client=self._client,
)

return app


class ConfidentialClientCredential(MsalCredential):
"""Wraps an MSAL ConfidentialClientApplication with the TokenCredential API"""

@wrap_exceptions
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> AccessToken

# MSAL requires scopes be a list
scopes = list(scopes) # type: ignore
now = int(time.time())

# First try to get a cached access token or if a refresh token is cached, redeem it for an access token.
# Failing that, acquire a new token.
app = self._get_app()
result = app.acquire_token_silent(scopes, account=None) or app.acquire_token_for_client(scopes)

if "access_token" not in result:
raise ClientAuthenticationError(message="authentication failed: {}".format(result.get("error_description")))

return AccessToken(result["access_token"], now + int(result["expires_in"]))

def _get_app(self):
# type: () -> msal.ConfidentialClientApplication
if not self._msal_app:
self._msal_app = self._create_app(msal.ConfidentialClientApplication)
return self._msal_app


class PublicClientCredential(MsalCredential):
"""Wraps an MSAL PublicClientApplication with the TokenCredential API"""

Expand Down
Loading