Skip to content

Commit

Permalink
Synchronous device code credential (#6464)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Aug 1, 2019
1 parent a08c25a commit b0bd437
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 19 deletions.
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/azure/identity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CertificateCredential,
ChainedTokenCredential,
ClientSecretCredential,
DeviceCodeCredential,
EnvironmentCredential,
ManagedIdentityCredential,
UsernamePasswordCredential,
Expand Down Expand Up @@ -35,6 +36,7 @@ def __init__(self, **kwargs):
"ChainedTokenCredential",
"ClientSecretCredential",
"DefaultAzureCredential",
"DeviceCodeCredential",
"EnvironmentCredential",
"InteractiveBrowserCredential",
"ManagedIdentityCredential",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
# Licensed under the MIT License.
# ------------------------------------
from .auth_code_redirect_handler import AuthCodeRedirectServer
from .exception_wrapper import wrap_exceptions
from .msal_credentials import ConfidentialClientCredential, PublicClientCredential
from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import functools

from six import raise_from

from azure.core.exceptions import ClientAuthenticationError


def wrap_exceptions(fn):
"""Prevents leaking exceptions defined outside azure-core by raising ClientAuthenticationError from them."""

@functools.wraps(fn)
def wrapper(*args, **kwargs):
try:
return fn(*args, **kwargs)
except ClientAuthenticationError:
raise
except Exception as ex:
auth_error = ClientAuthenticationError(message="Authentication failed: {}".format(ex))
raise_from(auth_error, ex)

return wrapper
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError

from .exception_wrapper import wrap_exceptions
from .msal_transport_adapter import MsalTransportAdapter

try:
Expand Down Expand Up @@ -75,6 +76,7 @@ def _create_app(self, cls):
class ConfidentialClientCredential(MsalCredential):
"""Wraps an MSAL ConfidentialClientApplication with the TokenCredential API"""

@wrap_exceptions
def get_token(self, *scopes):
# type: (str) -> AccessToken

Expand Down
3 changes: 2 additions & 1 deletion sdk/identity/azure-identity/azure/identity/browser_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError

from ._internal import AuthCodeRedirectServer, ConfidentialClientCredential
from ._internal import AuthCodeRedirectServer, ConfidentialClientCredential, wrap_exceptions


class InteractiveBrowserCredential(ConfidentialClientCredential):
Expand Down Expand Up @@ -48,6 +48,7 @@ def __init__(self, client_id, client_secret, **kwargs):
client_id=client_id, client_credential=client_secret, authority=authority, **kwargs
)

@wrap_exceptions
def get_token(self, *scopes):
# type: (str) -> AccessToken
"""
Expand Down
91 changes: 87 additions & 4 deletions sdk/identity/azure-identity/azure/identity/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ._authn_client import AuthnClient
from ._base import ClientSecretCredentialBase, CertificateCredentialBase
from ._internal import PublicClientCredential
from ._internal import PublicClientCredential, wrap_exceptions
from ._managed_identity import ImdsCredential, MsiCredential
from .constants import Endpoints, EnvironmentVariables

Expand All @@ -26,8 +26,9 @@

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Dict, Mapping, Optional, Union
from typing import Any, Callable, Dict, Mapping, Optional, Union
from azure.core.credentials import TokenCredential

EnvironmentCredentialTypes = Union["CertificateCredential", "ClientSecretCredential", "UsernamePasswordCredential"]

# pylint:disable=too-few-public-methods
Expand Down Expand Up @@ -249,6 +250,86 @@ def _get_error_message(history):
return "No valid token received. {}".format(". ".join(attempts))


class DeviceCodeCredential(PublicClientCredential):
"""
Authenticates users through the device code flow. When ``get_token`` is called, this credential acquires a
verification URL and code from Azure Active Directory. A user must browse to the URL, enter the code, and
authenticate with Directory. If the user authenticates successfully, the credential receives an access token.
This credential doesn't cache tokens--each ``get_token`` call begins a new authentication flow.
For more information about the device code flow, see Azure Active Directory documentation:
https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-device-code
:param str client_id: the application's ID
:param prompt_callback: (optional) A callback enabling control of how authentication instructions are presented.
If not provided, the credential will print instructions to stdout.
:type prompt_callback: A callable accepting arguments (``verification_uri``, ``user_code``, ``expires_in``):
- ``verification_uri`` (str) the URL the user must visit
- ``user_code`` (str) the code the user must enter there
- ``expires_in`` (int) the number of seconds the code will be valid
**Keyword arguments:**
- *tenant (str)* - tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the
'organizations' tenant, which supports only Azure Active Directory work or school accounts.
- *timeout (int)* - seconds to wait for the user to authenticate. Defaults to the validity period of the device code
as set by Azure Active Directory, which also prevails when ``timeout`` is longer.
"""

def __init__(self, client_id, prompt_callback=None, **kwargs):
# type: (str, Optional[Callable[[str, str], None]], Any) -> None
self._timeout = kwargs.pop("timeout", None) # type: Optional[int]
self._prompt_callback = prompt_callback
super(DeviceCodeCredential, self).__init__(client_id=client_id, **kwargs)

@wrap_exceptions
def get_token(self, *scopes):
# type (*str) -> AccessToken
"""
Request an access token for `scopes`. This credential won't cache the token. Each call begins a new
authentication flow.
:param str scopes: desired scopes for the token
:rtype: :class:`azure.core.credentials.AccessToken`
:raises: :class:`azure.core.exceptions.ClientAuthenticationError`
"""

# MSAL requires scopes be a list
scopes = list(scopes) # type: ignore
now = int(time.time())

app = self._get_app()
flow = app.initiate_device_flow(scopes)
if "error" in flow:
raise ClientAuthenticationError(
message="Couldn't begin authentication: {}".format(flow.get("error_description") or flow.get("error"))
)

if self._prompt_callback:
self._prompt_callback(flow["verification_uri"], flow["user_code"], flow["expires_in"])
else:
print(flow["message"])

if self._timeout is not None and self._timeout < flow["expires_in"]:
deadline = now + self._timeout
result = app.acquire_token_by_device_flow(flow, exit_condition=lambda flow: time.time() > deadline)
else:
result = app.acquire_token_by_device_flow(flow)

if "access_token" not in result:
if result.get("error") == "authorization_pending":
message = "Timed out waiting for user to authenticate"
else:
message = "Authentication failed: {}".format(result.get("error_description") or result.get("error"))
raise ClientAuthenticationError(message=message)

token = AccessToken(result["access_token"], now + int(result["expires_in"]))
return token


class UsernamePasswordCredential(PublicClientCredential):
"""
Authenticates a user with a username and password. In general, Microsoft doesn't recommend this kind of
Expand All @@ -267,8 +348,9 @@ class UsernamePasswordCredential(PublicClientCredential):
**Keyword arguments:**
*tenant (str)* - a tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the
'organizations' tenant.
- **tenant (str)** - a tenant ID or a domain associated with a tenant. If not provided, defaults to the
'organizations' tenant.
"""

def __init__(self, client_id, username, password, **kwargs):
Expand All @@ -277,6 +359,7 @@ def __init__(self, client_id, username, password, **kwargs):
self._username = username
self._password = password

@wrap_exceptions
def get_token(self, *scopes):
# type (*str) -> AccessToken
"""
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,6 @@
"azure",
]
),
install_requires=["azure-core<2.0.0,>=1.0.0b1", "cryptography>=2.1.4", "msal~=0.4.1"],
install_requires=["azure-core<2.0.0,>=1.0.0b1", "cryptography>=2.1.4", "msal~=0.4.1", "six>=1.6"],
extras_require={":python_version<'3.0'": ["azure-nspkg"], ":python_version<'3.5'": ["typing"]},
)
69 changes: 62 additions & 7 deletions sdk/identity/azure-identity/tests/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,21 @@
except ImportError: # python < 3.3
from mock import Mock, patch

import pytest
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.identity import (
ChainedTokenCredential,
ClientSecretCredential,
DefaultAzureCredential,
DeviceCodeCredential,
EnvironmentCredential,
ManagedIdentityCredential,
ChainedTokenCredential,
InteractiveBrowserCredential,
UsernamePasswordCredential,
)
from azure.identity._managed_identity import ImdsCredential
from azure.identity.constants import EnvironmentVariables
import pytest

from helpers import mock_response, Request, validating_transport

Expand Down Expand Up @@ -123,11 +124,6 @@ def test_client_secret_environment_credential(monkeypatch):
assert token.token == access_token


def test_environment_credential_error():
with pytest.raises(ClientAuthenticationError):
EnvironmentCredential().get_token("scope")


def test_credential_chain_error_message():
def raise_authn_error(message):
raise ClientAuthenticationError(message)
Expand Down Expand Up @@ -244,6 +240,65 @@ def test_default_credential():
DefaultAzureCredential()


def test_device_code_credential():
expected_token = "access-token"
user_code = "user-code"
verification_uri = "verification-uri"
expires_in = 42

transport = validating_transport(
requests=[Request()] * 3, # not validating requests because they're formed by MSAL
responses=[
# expected requests: discover tenant, start device code flow, poll for completion
mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}),
mock_response(
json_payload={"device_code": "_", "user_code": user_code, "verification_uri": verification_uri, "expires_in": expires_in}
),
mock_response(
json_payload={
"access_token": expected_token,
"expires_in": expires_in,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "_",
}
),
],
)

callback = Mock()
credential = DeviceCodeCredential(
client_id="_", prompt_callback=callback, transport=transport, instance_discovery=False
)

token = credential.get_token("scope")
assert token.token == expected_token

# prompt_callback should have been called as documented
assert callback.call_count == 1
assert callback.call_args[0] == (verification_uri, user_code, expires_in)


def test_device_code_credential_timeout():
transport = validating_transport(
requests=[Request()] * 3, # not validating requests because they're formed by MSAL
responses=[
# expected requests: discover tenant, start device code flow, poll for completion
mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}),
mock_response(json_payload={"device_code": "_", "user_code": "_", "verification_uri": "_"}),
mock_response(json_payload={"error": "authorization_pending"}),
],
)

credential = DeviceCodeCredential(
client_id="_", prompt_callback=Mock(), transport=transport, timeout=0.1, instance_discovery=False
)

with pytest.raises(ClientAuthenticationError) as ex:
credential.get_token("scope")
assert "timed out" in ex.value.message.lower()


@patch("azure.identity.browser_auth.webbrowser.open", lambda _: None) # prevent the credential opening a browser
def test_interactive_credential():
oauth_state = "state"
Expand Down
6 changes: 0 additions & 6 deletions sdk/identity/azure-identity/tests/test_identity_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,6 @@ async def test_client_secret_environment_credential(monkeypatch):
assert token.token == access_token


@pytest.mark.asyncio
async def test_environment_credential_error():
with pytest.raises(ClientAuthenticationError):
await EnvironmentCredential().get_token("scope")


@pytest.mark.asyncio
async def test_credential_chain_error_message():
def raise_authn_error(message):
Expand Down

0 comments on commit b0bd437

Please sign in to comment.