diff --git a/sdk/identity/azure-identity/azure/identity/__init__.py b/sdk/identity/azure-identity/azure/identity/__init__.py index fd23d51c8e32..1df14dd617f5 100644 --- a/sdk/identity/azure-identity/azure/identity/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/__init__.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from .browser_auth import InteractiveBrowserCredential from .credentials import ( CertificateCredential, ChainedTokenCredential, @@ -35,6 +36,7 @@ def __init__(self, **kwargs): "ClientSecretCredential", "DefaultAzureCredential", "EnvironmentCredential", + "InteractiveBrowserCredential", "ManagedIdentityCredential", "UsernamePasswordCredential", ] diff --git a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py index cfdf935b801c..e7583139395b 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py @@ -2,5 +2,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from .auth_code_redirect_handler import AuthCodeRedirectServer from .msal_credentials import ConfidentialClientCredential, PublicClientCredential from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse diff --git a/sdk/identity/azure-identity/azure/identity/_internal/auth_code_redirect_handler.py b/sdk/identity/azure-identity/azure/identity/_internal/auth_code_redirect_handler.py new file mode 100644 index 000000000000..caaa9519033c --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_internal/auth_code_redirect_handler.py @@ -0,0 +1,75 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +try: + from typing import TYPE_CHECKING +except ImportError: + TYPE_CHECKING = False + +if TYPE_CHECKING: + from typing import Any, Mapping, Optional + +try: + from http.server import HTTPServer, BaseHTTPRequestHandler +except ImportError: + from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler # type: ignore + +try: + from urllib.parse import parse_qs +except ImportError: + from urlparse import parse_qs # type: ignore + + +class AuthCodeRedirectHandler(BaseHTTPRequestHandler): + """HTTP request handler to capture the authentication server's response. + Largely from the Azure CLI: https://github.com/Azure/azure-cli/blob/dev/src/azure-cli-core/azure/cli/core/_profile.py + """ + + def do_GET(self): + if self.path.endswith("/favicon.ico"): # deal with legacy IE + self.send_response(204) + return + + query = self.path.split("?", 1)[-1] + query = parse_qs(query, keep_blank_values=True) + self.server.query_params = query + + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + + self.wfile.write(b"Authentication complete. You can close this window.") + + def log_message(self, format, *args): # pylint: disable=redefined-builtin,unused-argument,no-self-use + pass # this prevents server dumping messages to stdout + + +class AuthCodeRedirectServer(HTTPServer): + """HTTP server that listens on localhost for the redirect request following an authorization code authentication""" + + query_params = {} # type: Mapping[str, Any] + + def __init__(self, port, timeout): + # type: (int, int) -> None + super(AuthCodeRedirectServer, self).__init__(("localhost", port), AuthCodeRedirectHandler) + self.timeout = timeout + + def wait_for_redirect(self): + # type: () -> Mapping[str, Any] + while not self.query_params: + try: + self.handle_request() + except ValueError: + # socket has been closed, probably by handle_timeout + break + + # ensure the underlying socket is closed (a no-op when the socket is already closed) + self.server_close() + + # if we timed out, this returns an empty dict + return self.query_params + + def handle_timeout(self): + """Break the request-handling loop by tearing down the server""" + self.server_close() diff --git a/sdk/identity/azure-identity/azure/identity/browser_auth.py b/sdk/identity/azure-identity/azure/identity/browser_auth.py new file mode 100644 index 000000000000..82460513092d --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/browser_auth.py @@ -0,0 +1,121 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import socket +import time +import uuid +import webbrowser + +try: + from typing import TYPE_CHECKING +except ImportError: + TYPE_CHECKING = False + +if TYPE_CHECKING: + from typing import Any, List, Mapping + +from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError + +from ._internal import AuthCodeRedirectServer, ConfidentialClientCredential + + +class InteractiveBrowserCredential(ConfidentialClientCredential): + """ + Authenticates a user through the authorization code flow. This is an interactive flow: ``get_token`` opens a + browser to a login URL provided by Azure Active Directory, and waits for the user to authenticate there. + + Azure Active Directory documentation describes the authorization code flow in more detail: + https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-protocols-oauth-code + + :param str client_id: the application's client ID + :param str secret: one of the application's client secrets + + **Keyword arguments:** + + *tenant (str)* - a tenant ID or a domain associated with a tenant. If not provided, the credential defaults to the + 'organizations' tenant, which can authenticate work or school accounts. + *timeout (str)* - seconds to wait for the user to complete authentication. Defaults to 300 (5 minutes). + """ + + def __init__(self, client_id, client_secret, **kwargs): + # type: (str, str, Any) -> None + self._timeout = kwargs.pop("timeout", 300) + self._server_class = kwargs.pop("server_class", AuthCodeRedirectServer) # facilitate mocking + authority = "https://login.microsoftonline.com/" + kwargs.pop("tenant", "organizations") + super(InteractiveBrowserCredential, self).__init__( + client_id=client_id, client_credential=client_secret, authority=authority, **kwargs + ) + + def get_token(self, *scopes): + # type: (str) -> AccessToken + """ + Request an access token for `scopes`. This will open a browser to a login page and listen on localhost for a + request indicating authentication has completed. + + :param str scopes: desired scopes for the token + :rtype: :class:`azure.core.credentials.AccessToken` + :raises: :class:`azure.core.exceptions.ClientAuthenticationError` + """ + + # start an HTTP server on localhost to receive the redirect + for port in range(8400, 9000): + try: + server = self._server_class(port, timeout=self._timeout) + redirect_uri = "http://localhost:{}".format(port) + break + except socket.error: + continue # keep looking for an open port + + if not redirect_uri: + raise ClientAuthenticationError(message="Couldn't start an HTTP server on localhost") + + # get the url the user must visit to authenticate + scopes = list(scopes) # type: ignore + request_state = str(uuid.uuid4()) + app = self._get_app() + auth_url = app.get_authorization_request_url(scopes, redirect_uri=redirect_uri, state=request_state) + + # open browser to that url + webbrowser.open(auth_url) + + # block until the server times out or receives the post-authentication redirect + response = server.wait_for_redirect() + if not response: + raise ClientAuthenticationError( + message="Timed out after waiting {} seconds for the user to authenticate".format(self._timeout) + ) + + # redeem the authorization code for a token + code = self._parse_response(request_state, response) + now = int(time.time()) + result = app.acquire_token_by_authorization_code(code, scopes=scopes, redirect_uri=redirect_uri) + + if "access_token" not in result: + raise ClientAuthenticationError(message="Authentication failed: {}".format(result.get("error_description"))) + + return AccessToken(result["access_token"], now + int(result["expires_in"])) + + def _parse_response(self, request_state, response): + # type: (str, Mapping[str, Any]) -> List[str] + """ + Validates ``response`` and returns the authorization code it contains, if authentication succeeded. Raises + :class:`azure.core.exceptions.ClientAuthenticationError`, if authentication failed or ``response`` is malformed. + """ + + if "error" in response: + message = "Authentication failed: {}".format(response.get("error_description") or response["error"]) + raise ClientAuthenticationError(message=message) + if "code" not in response: + # a response with no error or code is malformed; we don't know what to do with it + message = "Authentication server didn't send an authorization code" + raise ClientAuthenticationError(message=message) + + # response must include the state sent in the auth request + if "state" not in response: + raise ClientAuthenticationError(message="Authentication response doesn't include OAuth state") + if response["state"][0] != request_state: + raise ClientAuthenticationError(message="Authentication response's OAuth state doesn't match the request's") + + return response["code"] diff --git a/sdk/identity/azure-identity/tests/test_identity.py b/sdk/identity/azure-identity/tests/test_identity.py index 4e7fb13b4f6b..3aaa54a1f28f 100644 --- a/sdk/identity/azure-identity/tests/test_identity.py +++ b/sdk/identity/azure-identity/tests/test_identity.py @@ -2,15 +2,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import functools import json import os import time import uuid try: - from unittest.mock import Mock + from unittest.mock import Mock, patch except ImportError: # python < 3.3 - from mock import Mock + from mock import Mock, patch import pytest from azure.core.credentials import AccessToken @@ -21,6 +22,7 @@ EnvironmentCredential, ManagedIdentityCredential, ChainedTokenCredential, + InteractiveBrowserCredential, UsernamePasswordCredential, ) from azure.identity._managed_identity import ImdsCredential @@ -242,6 +244,73 @@ def test_default_credential(): DefaultAzureCredential() +@patch("azure.identity.browser_auth.webbrowser.open", lambda _: None) # prevent the credential opening a browser +def test_interactive_credential(): + oauth_state = "state" + expected_token = "access-token" + + transport = validating_transport( + requests=[Request()] * 2, # not validating requests because they're formed by MSAL + responses=[ + # expecting tenant discovery then a token request + mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}), + mock_response( + json_payload={ + "access_token": expected_token, + "expires_in": 42, + "token_type": "Bearer", + "ext_expires_in": 42, + } + ), + ], + ) + + # mock local server fakes successful authentication by immediately returning a well-formed response + auth_code_response = {"code": "authorization-code", "state": [oauth_state]} + server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response)) + + credential = InteractiveBrowserCredential( + client_id="guid", + client_secret="secret", + server_class=server_class, + transport=transport, + instance_discovery=False, # kwargs are passed to MSAL; this one prevents an AAD verification request + ) + + # ensure the request beginning the flow has a known state value + with patch("azure.identity.browser_auth.uuid.uuid4", lambda: oauth_state): + token = credential.get_token("scope") + assert token.token == expected_token + + +@patch("azure.identity.browser_auth.webbrowser.open", lambda _: None) # prevent the credential opening a browser +def test_interactive_credential_timeout(): + # mock transport handles MSAL's tenant discovery + transport = Mock( + send=lambda _, **__: mock_response( + json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"} + ) + ) + + # mock local server blocks long enough to exceed the timeout + timeout = 1 + server_instance = Mock(wait_for_redirect=functools.partial(time.sleep, timeout + 1)) + server_class = Mock(return_value=server_instance) + + credential = InteractiveBrowserCredential( + client_id="guid", + client_secret="secret", + server_class=server_class, + timeout=timeout, + transport=transport, + instance_discovery=False, # kwargs are passed to MSAL; this one prevents an AAD verification request + ) + + with pytest.raises(ClientAuthenticationError) as ex: + credential.get_token("scope") + assert "timed out" in ex.value.message.lower() + + def test_username_password_credential(): expected_token = "access-token" transport = validating_transport(