Skip to content

Commit

Permalink
Add EnvVar AZURE_AUTHORITY_HOST (#10357)
Browse files Browse the repository at this point in the history
  • Loading branch information
v-xuto authored Apr 6, 2020
1 parent 3639cdb commit 19ae038
Show file tree
Hide file tree
Showing 23 changed files with 318 additions and 142 deletions.
4 changes: 4 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ state. ([#10243](https://github.com/Azure/azure-sdk-for-python/issues/10243))
cache is available but contains ambiguous or insufficient information. This
causes `ChainedTokenCredential` to correctly try the next credential in the
chain. ([#10631](https://github.com/Azure/azure-sdk-for-python/issues/10631))
- The host of the Active Directory endpoint credentials should use can be set
in the environment variable `AZURE_AUTHORITY_HOST`. See
`azure.identity.KnownAuthorities` for a list of common values.
([#8094](https://github.com/Azure/azure-sdk-for-python/issues/8094))


## 1.3.1 (2020-03-30)
Expand Down
5 changes: 3 additions & 2 deletions sdk/identity/azure-identity/azure/identity/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
UserAgentPolicy,
)
from azure.core.pipeline.transport import RequestsTransport, HttpRequest
from ._constants import AZURE_CLI_CLIENT_ID, KnownAuthorities
from ._constants import AZURE_CLI_CLIENT_ID
from ._internal import get_default_authority
from ._internal.user_agent import USER_AGENT

try:
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(self, endpoint=None, authority=None, tenant=None, **kwargs): # pyl
else:
if not tenant:
raise ValueError("'tenant' is required")
authority = authority or KnownAuthorities.AZURE_PUBLIC_CLOUD
authority = authority or get_default_authority()
self._auth_url = "https://" + "/".join((authority.strip("/"), tenant.strip("/"), "oauth2/v2.0/token"))
self._cache = kwargs.get("cache") or TokenCache() # type: TokenCache

Expand Down
1 change: 1 addition & 0 deletions sdk/identity/azure-identity/azure/identity/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class EnvironmentVariables:

MSI_ENDPOINT = "MSI_ENDPOINT"
MSI_SECRET = "MSI_SECRET"
AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST"


class Endpoints:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import logging
import os

from .._constants import EnvironmentVariables, KnownAuthorities
from .._constants import EnvironmentVariables
from .._internal import get_default_authority
from .browser import InteractiveBrowserCredential
from .chained import ChainedTokenCredential
from .environment import EnvironmentCredential
Expand Down Expand Up @@ -61,7 +62,7 @@ class DefaultAzureCredential(ChainedTokenCredential):
"""

def __init__(self, **kwargs):
authority = kwargs.pop("authority", None) or KnownAuthorities.AZURE_PUBLIC_CLOUD
authority = kwargs.pop("authority", None) or get_default_authority()

shared_cache_username = kwargs.pop("shared_cache_username", os.environ.get(EnvironmentVariables.AZURE_USERNAME))
shared_cache_tenant_id = kwargs.pop(
Expand Down
10 changes: 10 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os

from .._constants import EnvironmentVariables, KnownAuthorities


def get_default_authority():
return os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD)


# pylint:disable=wrong-import-position
from .aad_client import AadClient
from .aad_client_base import AadClientBase
from .auth_code_redirect_handler import AuthCodeRedirectServer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from .._constants import KnownAuthorities
from . import get_default_authority

try:
ABC = abc.ABC
Expand All @@ -34,7 +34,7 @@ class AadClientBase(ABC):

def __init__(self, tenant_id, client_id, cache=None, **kwargs):
# type: (str, str, Optional[TokenCache], **Any) -> None
authority = kwargs.pop("authority", KnownAuthorities.AZURE_PUBLIC_CLOUD)
authority = kwargs.pop("authority", None) or get_default_authority()
if authority[-1] == "/":
authority = authority[:-1]
token_endpoint = "https://" + "/".join((authority, tenant_id, "oauth2/v2.0/token"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from .exception_wrapper import wrap_exceptions
from .msal_transport_adapter import MsalTransportAdapter
from .._constants import KnownAuthorities
from .._internal import get_default_authority

try:
ABC = abc.ABC
Expand All @@ -37,7 +37,7 @@ class MsalCredential(ABC):
def __init__(self, client_id, client_credential=None, **kwargs):
# type: (str, Optional[Union[str, Mapping[str, str]]], **Any) -> None
tenant_id = kwargs.pop("tenant_id", "organizations")
authority = kwargs.pop("authority", KnownAuthorities.AZURE_PUBLIC_CLOUD)
authority = kwargs.pop("authority", None) or get_default_authority()
self._base_url = "https://" + "/".join((authority.strip("/"), tenant_id.strip("/")))
self._client_credential = client_credential
self._client_id = client_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .. import CredentialUnavailableError
from .._constants import KnownAuthorities
from .._internal import get_default_authority

try:
ABC = abc.ABC
Expand Down Expand Up @@ -86,7 +87,7 @@ class SharedTokenCacheBase(ABC):
def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
# type: (Optional[str], **Any) -> None

self._authority = kwargs.pop("authority", None) or KnownAuthorities.AZURE_PUBLIC_CLOUD
self._authority = kwargs.pop("authority", None) or get_default_authority()
self._authority_aliases = KNOWN_ALIASES.get(self._authority) or frozenset((self._authority,))
self._username = username
self._tenant_id = kwargs.pop("tenant_id", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import os
from typing import TYPE_CHECKING

from ..._constants import EnvironmentVariables, KnownAuthorities
from ..._constants import EnvironmentVariables
from ..._internal import get_default_authority
from .azure_cli import AzureCliCredential
from .chained import ChainedTokenCredential
from .environment import EnvironmentCredential
Expand Down Expand Up @@ -52,7 +53,7 @@ class DefaultAzureCredential(ChainedTokenCredential):
"""

def __init__(self, **kwargs):
authority = kwargs.pop("authority", None) or KnownAuthorities.AZURE_PUBLIC_CLOUD
authority = kwargs.pop("authority", None) or get_default_authority()

shared_cache_username = kwargs.pop("shared_cache_username", os.environ.get(EnvironmentVariables.AZURE_USERNAME))
shared_cache_tenant_id = kwargs.pop(
Expand Down
11 changes: 9 additions & 2 deletions sdk/identity/azure-identity/tests/test_aad_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
import functools

from azure.core.exceptions import ClientAuthenticationError
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.aad_client import AadClient
import pytest
from six.moves.urllib_parse import urlparse

from helpers import mock_response

try:
from unittest.mock import Mock
from unittest.mock import Mock, patch
except ImportError: # python < 3.3
from mock import Mock # type: ignore
from mock import Mock, patch # type: ignore


class MockClient(AadClient):
Expand Down Expand Up @@ -113,3 +114,9 @@ def send(request, **_):

client.obtain_token_by_authorization_code("code", "uri", "scope")
client.obtain_token_by_refresh_token("refresh token", "scope")

# authority can be configured via environment variable
with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True):
client = AadClient(tenant_id=tenant_id, client_id="client id", transport=Mock(send=send))
client.obtain_token_by_authorization_code("code", "uri", "scope")
client.obtain_token_by_refresh_token("refresh token", "scope")
9 changes: 8 additions & 1 deletion sdk/identity/azure-identity/tests/test_aad_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from unittest.mock import Mock
from unittest.mock import Mock, patch
from urllib.parse import urlparse

from azure.identity._constants import EnvironmentVariables
from azure.identity.aio._internal.aad_client import AadClient
import pytest

Expand Down Expand Up @@ -57,3 +58,9 @@ async def send(request, **_):

await client.obtain_token_by_authorization_code("code", "uri", "scope")
await client.obtain_token_by_refresh_token("refresh token", "scope")

# authority can be configured via environment variable
with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True):
client = AadClient(tenant_id=tenant, client_id="client id", transport=Mock(send=send))
await client.obtain_token_by_authorization_code("code", "uri", "scope")
await client.obtain_token_by_refresh_token("refresh token", "scope")
10 changes: 9 additions & 1 deletion sdk/identity/azure-identity/tests/test_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from azure.core.credentials import AccessToken
from azure.identity._authn_client import AuthnClient
from azure.identity._constants import EnvironmentVariables
from six.moves.urllib_parse import urlparse
from helpers import mock_response

Expand Down Expand Up @@ -205,7 +206,7 @@ def mock_send(request, **kwargs):


def test_request_url():
authority = "authority.com"
authority = "localhost"
tenant = "expected_tenant"

def validate_url(url):
Expand All @@ -222,3 +223,10 @@ def mock_send(request, **kwargs):
client.request_token(("scope",))
request = client.get_refresh_token_grant_request({"secret": "***"}, "scope")
validate_url(request.url)

# authority can be configured via environment variable
with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True):
client = AuthnClient(tenant=tenant, transport=Mock(send=mock_send))
client.request_token(("scope",))
request = client.get_refresh_token_grant_request({"secret": "***"}, "scope")
validate_url(request.url)
8 changes: 7 additions & 1 deletion sdk/identity/azure-identity/tests/test_authn_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
# Licensed under the MIT License.
# ------------------------------------
import asyncio
from unittest.mock import Mock
from unittest.mock import Mock, patch
from urllib.parse import urlparse

import pytest
from azure.identity._constants import EnvironmentVariables
from azure.identity.aio._authn_client import AsyncAuthnClient

from helpers import mock_response
Expand All @@ -27,3 +28,8 @@ def mock_send(request, **kwargs):

client = AsyncAuthnClient(tenant=tenant, transport=Mock(send=wrap_in_future(mock_send)), authority=authority)
await client.request_token(("scope",))

# authority can be configured via environment variable
with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True):
client = AsyncAuthnClient(tenant=tenant, transport=Mock(send=wrap_in_future(mock_send)))
await client.request_token(("scope",))
13 changes: 11 additions & 2 deletions sdk/identity/azure-identity/tests/test_certificate_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy
from azure.identity import CertificateCredential
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.user_agent import USER_AGENT
from cryptography import x509
from cryptography.hazmat.backends import default_backend
Expand All @@ -18,9 +19,9 @@
from helpers import build_aad_response, urlsafeb64_decode, mock_response, Request, validating_transport

try:
from unittest.mock import Mock
from unittest.mock import Mock, patch
except ImportError: # python < 3.3
from mock import Mock # type: ignore
from mock import Mock, patch # type: ignore

CERT_PATH = os.path.join(os.path.dirname(__file__), "certificate.pem")
CERT_WITH_PASSWORD_PATH = os.path.join(os.path.dirname(__file__), "certificate-with-password.pem")
Expand Down Expand Up @@ -84,6 +85,14 @@ def mock_send(request, **kwargs):
token = cred.get_token("scope")
assert token.token == access_token

# authority can be configured via environment variable
with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True):
credential = CertificateCredential(
tenant_id, "client-id", cert_path, password=cert_password, transport=Mock(send=mock_send)
)
credential.get_token("scope")
assert token.token == access_token


@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS)
def test_request_body(cert_path, cert_password):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from unittest.mock import Mock
from unittest.mock import Mock, patch
from urllib.parse import urlparse

from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.user_agent import USER_AGENT
from azure.identity.aio import CertificateCredential

Expand Down Expand Up @@ -98,6 +99,14 @@ async def mock_send(request, **kwargs):
token = await cred.get_token("scope")
assert token.token == access_token

# authority can be configured via environment variable
with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True):
credential = CertificateCredential(
tenant_id, "client-id", cert_path, password=cert_password, transport=Mock(send=mock_send)
)
await credential.get_token("scope")
assert token.token == access_token


@pytest.mark.asyncio
@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS)
Expand Down
32 changes: 30 additions & 2 deletions sdk/identity/azure-identity/tests/test_client_secret_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@
from azure.core.credentials import AccessToken
from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy
from azure.identity import ClientSecretCredential
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.user_agent import USER_AGENT
import pytest
from six.moves.urllib_parse import urlparse

from helpers import build_aad_response, mock_response, Request, validating_transport

try:
from unittest.mock import Mock
from unittest.mock import Mock, patch
except ImportError: # python < 3.3
from mock import Mock # type: ignore
from mock import Mock, patch # type: ignore


def test_no_scopes():
Expand Down Expand Up @@ -78,6 +80,32 @@ def test_client_secret_credential():
assert token.token == access_token


def test_request_url():
authority = "localhost"
tenant_id = "expected_tenant"
access_token = "***"

def mock_send(request, **kwargs):
parsed = urlparse(request.url)
assert parsed.scheme == "https"
assert parsed.netloc == authority
assert parsed.path.startswith("/" + tenant_id)

return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token})

credential = ClientSecretCredential(
tenant_id, "client-id", "secret", transport=Mock(send=mock_send), authority=authority
)
token = credential.get_token("scope")
assert token.token == access_token

# authority can be configured via environment variable
with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True):
credential = ClientSecretCredential(tenant_id, "client-id", "secret", transport=Mock(send=mock_send))
credential.get_token("scope")
assert token.token == access_token


def test_cache():
expired = "this token's expired"
now = int(time.time())
Expand Down
Loading

0 comments on commit 19ae038

Please sign in to comment.