Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement synchronous interactive authentication #6466

Merged
merged 4 commits into from
Jul 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/azure/identity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from .browser_auth import InteractiveBrowserCredential
from .credentials import (
CertificateCredential,
ChainedTokenCredential,
Expand Down Expand Up @@ -35,6 +36,7 @@ def __init__(self, **kwargs):
"ClientSecretCredential",
"DefaultAzureCredential",
"EnvironmentCredential",
"InteractiveBrowserCredential",
"ManagedIdentityCredential",
"UsernamePasswordCredential",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from .auth_code_redirect_handler import AuthCodeRedirectServer
from .msal_credentials import ConfidentialClientCredential, PublicClientCredential
from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
from typing import Any, Mapping, Optional

try:
from http.server import HTTPServer, BaseHTTPRequestHandler
except ImportError:
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler # type: ignore

try:
from urllib.parse import parse_qs
except ImportError:
from urlparse import parse_qs # type: ignore


class AuthCodeRedirectHandler(BaseHTTPRequestHandler):
"""HTTP request handler to capture the authentication server's response.
Largely from the Azure CLI: https://github.com/Azure/azure-cli/blob/dev/src/azure-cli-core/azure/cli/core/_profile.py
"""

def do_GET(self):
if self.path.endswith("/favicon.ico"): # deal with legacy IE
self.send_response(204)
return

query = self.path.split("?", 1)[-1]
query = parse_qs(query, keep_blank_values=True)
self.server.query_params = query

self.send_response(200)
self.send_header("Content-Type", "text/html")
self.end_headers()

self.wfile.write(b"Authentication complete. You can close this window.")

def log_message(self, format, *args): # pylint: disable=redefined-builtin,unused-argument,no-self-use
pass # this prevents server dumping messages to stdout


class AuthCodeRedirectServer(HTTPServer):
"""HTTP server that listens on localhost for the redirect request following an authorization code authentication"""

query_params = {} # type: Mapping[str, Any]

def __init__(self, port, timeout):
# type: (int, int) -> None
super(AuthCodeRedirectServer, self).__init__(("localhost", port), AuthCodeRedirectHandler)
self.timeout = timeout

def wait_for_redirect(self):
# type: () -> Mapping[str, Any]
while not self.query_params:
try:
self.handle_request()
except ValueError:
# socket has been closed, probably by handle_timeout
break

# ensure the underlying socket is closed (a no-op when the socket is already closed)
self.server_close()

# if we timed out, this returns an empty dict
return self.query_params

def handle_timeout(self):
"""Break the request-handling loop by tearing down the server"""
self.server_close()
121 changes: 121 additions & 0 deletions sdk/identity/azure-identity/azure/identity/browser_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import socket
import time
import uuid
import webbrowser

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
from typing import Any, List, Mapping

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError

from ._internal import AuthCodeRedirectServer, ConfidentialClientCredential


class InteractiveBrowserCredential(ConfidentialClientCredential):
"""
Authenticates a user through the authorization code flow. This is an interactive flow: ``get_token`` opens a
browser to a login URL provided by Azure Active Directory, and waits for the user to authenticate there.

Azure Active Directory documentation describes the authorization code flow in more detail:
https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-protocols-oauth-code

:param str client_id: the application's client ID
:param str secret: one of the application's client secrets

**Keyword arguments:**

*tenant (str)* - a tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the
'organizations' tenant, which can authenticate work or school accounts.
*timeout (str)* - seconds to wait for the user to complete authentication. Defaults to 300 (5 minutes).
"""

def __init__(self, client_id, client_secret, **kwargs):
# type: (str, str, Any) -> None
self._timeout = kwargs.pop("timeout", 300)
self._server_class = kwargs.pop("server_class", AuthCodeRedirectServer) # facilitate mocking
authority = "https://login.microsoftonline.com/" + kwargs.pop("tenant", "organizations")
super(InteractiveBrowserCredential, self).__init__(
client_id=client_id, client_credential=client_secret, authority=authority, **kwargs
)

def get_token(self, *scopes):
# type: (str) -> AccessToken
"""
Request an access token for `scopes`. This will open a browser to a login page and listen on localhost for a
request indicating authentication has completed.

:param str scopes: desired scopes for the token
:rtype: :class:`azure.core.credentials.AccessToken`
:raises: :class:`azure.core.exceptions.ClientAuthenticationError`
"""

# start an HTTP server on localhost to receive the redirect
for port in range(8400, 9000):
try:
server = self._server_class(port, timeout=self._timeout)
redirect_uri = "http://localhost:{}".format(port)
break
except socket.error:
continue # keep looking for an open port

if not redirect_uri:
raise ClientAuthenticationError(message="Couldn't start an HTTP server on localhost")

# get the url the user must visit to authenticate
scopes = list(scopes) # type: ignore
request_state = str(uuid.uuid4())
app = self._get_app()
auth_url = app.get_authorization_request_url(scopes, redirect_uri=redirect_uri, state=request_state)

# open browser to that url
webbrowser.open(auth_url)

# block until the server times out or receives the post-authentication redirect
response = server.wait_for_redirect()
if not response:
raise ClientAuthenticationError(
message="Timed out after waiting {} seconds for the user to authenticate".format(self._timeout)
)

# redeem the authorization code for a token
code = self._parse_response(request_state, response)
now = int(time.time())
result = app.acquire_token_by_authorization_code(code, scopes=scopes, redirect_uri=redirect_uri)

if "access_token" not in result:
raise ClientAuthenticationError(message="Authentication failed: {}".format(result.get("error_description")))

return AccessToken(result["access_token"], now + int(result["expires_in"]))

def _parse_response(self, request_state, response):
# type: (str, Mapping[str, Any]) -> List[str]
"""
Validates ``response`` and returns the authorization code it contains, if authentication succeeded. Raises
:class:`azure.core.exceptions.ClientAuthenticationError`, if authentication failed or ``response`` is malformed.
"""

if "error" in response:
message = "Authentication failed: {}".format(response.get("error_description") or response["error"])
raise ClientAuthenticationError(message=message)
if "code" not in response:
# a response with no error or code is malformed; we don't know what to do with it
message = "Authentication server didn't send an authorization code"
raise ClientAuthenticationError(message=message)

# response must include the state sent in the auth request
if "state" not in response:
raise ClientAuthenticationError(message="Authentication response doesn't include OAuth state")
if response["state"][0] != request_state:
raise ClientAuthenticationError(message="Authentication response's OAuth state doesn't match the request's")

return response["code"]
73 changes: 71 additions & 2 deletions sdk/identity/azure-identity/tests/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import functools
import json
import os
import time
import uuid

try:
from unittest.mock import Mock
from unittest.mock import Mock, patch
except ImportError: # python < 3.3
from mock import Mock
from mock import Mock, patch

import pytest
from azure.core.credentials import AccessToken
Expand All @@ -21,6 +22,7 @@
EnvironmentCredential,
ManagedIdentityCredential,
ChainedTokenCredential,
InteractiveBrowserCredential,
UsernamePasswordCredential,
)
from azure.identity._managed_identity import ImdsCredential
Expand Down Expand Up @@ -242,6 +244,73 @@ def test_default_credential():
DefaultAzureCredential()


@patch("azure.identity.browser_auth.webbrowser.open", lambda _: None) # prevent the credential opening a browser
def test_interactive_credential():
oauth_state = "state"
expected_token = "access-token"

transport = validating_transport(
requests=[Request()] * 2, # not validating requests because they're formed by MSAL
responses=[
# expecting tenant discovery then a token request
mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}),
mock_response(
json_payload={
"access_token": expected_token,
"expires_in": 42,
"token_type": "Bearer",
"ext_expires_in": 42,
}
),
],
)

# mock local server fakes successful authentication by immediately returning a well-formed response
auth_code_response = {"code": "authorization-code", "state": [oauth_state]}
server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response))

credential = InteractiveBrowserCredential(
client_id="guid",
client_secret="secret",
server_class=server_class,
transport=transport,
instance_discovery=False, # kwargs are passed to MSAL; this one prevents an AAD verification request
)

# ensure the request beginning the flow has a known state value
with patch("azure.identity.browser_auth.uuid.uuid4", lambda: oauth_state):
token = credential.get_token("scope")
assert token.token == expected_token


@patch("azure.identity.browser_auth.webbrowser.open", lambda _: None) # prevent the credential opening a browser
def test_interactive_credential_timeout():
# mock transport handles MSAL's tenant discovery
transport = Mock(
send=lambda _, **__: mock_response(
json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}
)
)

# mock local server blocks long enough to exceed the timeout
timeout = 1
server_instance = Mock(wait_for_redirect=functools.partial(time.sleep, timeout + 1))
server_class = Mock(return_value=server_instance)

credential = InteractiveBrowserCredential(
client_id="guid",
client_secret="secret",
server_class=server_class,
timeout=timeout,
transport=transport,
instance_discovery=False, # kwargs are passed to MSAL; this one prevents an AAD verification request
)

with pytest.raises(ClientAuthenticationError) as ex:
credential.get_token("scope")
assert "timed out" in ex.value.message.lower()


def test_username_password_credential():
expected_token = "access-token"
transport = validating_transport(
Expand Down