Skip to content

Commit

Permalink
Expose methods for closing async credential transport sessions (#9090)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Jan 13, 2020
1 parent d8a9ffd commit 8b40aae
Show file tree
Hide file tree
Showing 34 changed files with 544 additions and 105 deletions.
23 changes: 23 additions & 0 deletions sdk/core/azure-core/azure/core/credentials_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any
from typing_extensions import Protocol
from .credentials import AccessToken

class AsyncTokenCredential(Protocol):
async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
pass

async def close(self) -> None:
pass

async def __aenter__(self):
pass

async def __aexit__(self, exc_type, exc_value, traceback) -> None:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,10 @@ class _BearerTokenCredentialPolicyBase(object):
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (TokenCredential, *str, Mapping[str, Any]) -> None
def __init__(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> None
super(_BearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._token = None # type: Optional[AccessToken]

@staticmethod
Expand Down Expand Up @@ -69,6 +68,11 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPo
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""

def __init__(self, credential, *scopes, **kwargs):
# type: (TokenCredential, *str, **Any) -> None
self._credential = credential
super(BearerTokenCredentialPolicy, self).__init__(*scopes, **kwargs)

def on_request(self, request):
# type: (PipelineRequest) -> None
"""Adds a bearer token Authorization header to request and sends request to next policy.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,33 @@
# license information.
# -------------------------------------------------------------------------
import threading
from typing import TYPE_CHECKING

from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.core.pipeline.policies._authentication import _BearerTokenCredentialPolicyBase

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.pipeline import PipelineRequest


class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
# pylint:disable=too-few-public-methods
"""Adds a bearer token Authorization header to requests.
:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:type credential: ~azure.core.credentials_async.AsyncTokenCredential
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential, *scopes, **kwargs):
super().__init__(credential, *scopes, **kwargs)
def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: "Any") -> None:
self._credential = credential
self._lock = threading.Lock()
super().__init__(*scopes, **kwargs)

async def on_request(self, request: PipelineRequest):
async def on_request(self, request: "PipelineRequest"):
"""Adds a bearer token Authorization header to request and sends request to next policy.
:param request: The pipeline request object to be modified.
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

- All credential pipelines include `ProxyPolicy`
([#8945](https://github.com/Azure/azure-sdk-for-python/pull/8945))
- Async credentials are async context managers and have an async `close` method
([#9090](https://github.com/Azure/azure-sdk-for-python/pull/9090))


## 1.1.0 (2019-11-27)
Expand Down
18 changes: 18 additions & 0 deletions sdk/identity/azure-identity/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,24 @@ async transport, such as [aiohttp](https://pypi.org/project/aiohttp/). See
[azure-core documentation](../../core/azure-core/README.md#transport)
for more information.

Async credentials should be closed when they're no longer needed. Each async
credential is an async context manager and defines an async `close` method. For
example:

```py
from azure.identity.aio import DefaultAzureCredential

# call close when the credential is no longer needed
credential = DefaultAzureCredential()
...
await credential.close()

# alternatively, use the credential as an async context manager
credential = DefaultAzureCredential()
async with credential:
...
```

This example demonstrates authenticating the asynchronous `SecretClient` from
[azure-keyvault-secrets][azure_keyvault_secrets] with an asynchronous
credential.
Expand Down
26 changes: 14 additions & 12 deletions sdk/identity/azure-identity/azure/identity/_credentials/chained.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@
from azure.core.credentials import AccessToken, TokenCredential


def _get_error_message(history):
attempts = []
for credential, error in history:
if error:
attempts.append("{}: {}".format(credential.__class__.__name__, error))
else:
attempts.append(credential.__class__.__name__)
return """No credential in this chain provided a token.
Attempted credentials:\n\t{}""".format(
"\n\t".join(attempts)
)


class ChainedTokenCredential(object):
"""A sequence of credentials that is itself a credential.
Expand Down Expand Up @@ -48,16 +61,5 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
history.append((credential, ex.message))
except Exception as ex: # pylint: disable=broad-except
history.append((credential, str(ex)))
error_message = self._get_error_message(history)
error_message = _get_error_message(history)
raise ClientAuthenticationError(message=error_message)

@staticmethod
def _get_error_message(history):
attempts = []
for credential, error in history:
if error:
attempts.append("{}: {}".format(credential.__class__.__name__, error))
else:
attempts.append(credential.__class__.__name__)
return """No credential in this chain provided a token.
Attempted credentials:\n\t{}""".format("\n\t".join(attempts))
10 changes: 10 additions & 0 deletions sdk/identity/azure-identity/azure/identity/aio/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ def __init__(
self._pipeline = AsyncPipeline(transport=transport, policies=policies)
super().__init__(**kwargs)

async def __aenter__(self):
await self._pipeline.__aenter__()
return self

async def __aexit__(self, *args):
await self.close()

async def close(self) -> None:
await self._pipeline.__aexit__()

async def request_token(
self,
scopes: "Iterable[str]",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import abc


class AsyncCredentialBase(abc.ABC):
@abc.abstractmethod
async def close(self):
pass

async def __aenter__(self):
return self

async def __aexit__(self, *args):
await self.close()

@abc.abstractmethod
async def get_token(self, *scopes, **kwargs):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import asyncio
from typing import TYPE_CHECKING

from azure.core.exceptions import ClientAuthenticationError
from ... import ChainedTokenCredential as SyncChainedTokenCredential
from .base import AsyncCredentialBase
from ..._credentials.chained import _get_error_message

if TYPE_CHECKING:
from typing import Any
from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential


class ChainedTokenCredential(SyncChainedTokenCredential):
class ChainedTokenCredential(AsyncCredentialBase):
"""A sequence of credentials that is itself a credential.
Its :func:`get_token` method calls ``get_token`` on each credential in the sequence, in order, returning the first
valid token received.
:param credentials: credential instances to form the chain
:type credentials: :class:`azure.core.credentials.TokenCredential`
:type credentials: :class:`azure.core.credentials.AsyncTokenCredential`
"""

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
def __init__(self, *credentials: "AsyncTokenCredential") -> None:
if not credentials:
raise ValueError("at least one credential is required")
self.credentials = credentials

async def close(self):
"""Close the transport sessions of all credentials in the chain."""

await asyncio.gather(*(credential.close() for credential in self.credentials))

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
"""Asynchronously request a token from each credential, in order, returning the first token received.
If no credential provides a token, raises :class:`azure.core.exceptions.ClientAuthenticationError`
Expand All @@ -41,5 +54,5 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
history.append((credential, ex.message))
except Exception as ex: # pylint: disable=broad-except
history.append((credential, str(ex)))
error_message = self._get_error_message(history)
error_message = _get_error_message(history)
raise ClientAuthenticationError(message=error_message)
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# ------------------------------------
from typing import TYPE_CHECKING

from .base import AsyncCredentialBase
from .._authn_client import AsyncAuthnClient
from ..._base import ClientSecretCredentialBase, CertificateCredentialBase

Expand All @@ -12,7 +13,7 @@
from azure.core.credentials import AccessToken


class ClientSecretCredential(ClientSecretCredentialBase):
class ClientSecretCredential(ClientSecretCredentialBase, AsyncCredentialBase):
"""Authenticates as a service principal using a client ID and client secret.
:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
Expand All @@ -28,6 +29,15 @@ def __init__(self, tenant_id: str, client_id: str, client_secret: str, **kwargs:
super(ClientSecretCredential, self).__init__(tenant_id, client_id, client_secret, **kwargs)
self._client = AsyncAuthnClient(tenant=tenant_id, **kwargs)

async def __aenter__(self):
await self._client.__aenter__()
return self

async def close(self):
"""Close the credential's transport session."""

await self._client.__aexit__()

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
"""Asynchronously request an access token for `scopes`.
Expand All @@ -44,7 +54,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
return token # type: ignore


class CertificateCredential(CertificateCredentialBase):
class CertificateCredential(CertificateCredentialBase, AsyncCredentialBase):
"""Authenticates as a service principal using a certificate.
:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
Expand All @@ -57,6 +67,15 @@ class CertificateCredential(CertificateCredentialBase):
defines authorities for other clouds.
"""

async def __aenter__(self):
await self._client.__aenter__()
return self

async def close(self):
"""Close the credential's transport session."""

await self._client.__aexit__()

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
"""Asynchronously request an access token for `scopes`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from azure.core.exceptions import ClientAuthenticationError
from ..._constants import EnvironmentVariables
from .client_credential import CertificateCredential, ClientSecretCredential
from .base import AsyncCredentialBase

if TYPE_CHECKING:
from typing import Any, Optional, Union
from azure.core.credentials import AccessToken


class EnvironmentCredential:
class EnvironmentCredential(AsyncCredentialBase):
"""A credential configured by environment variables.
This credential is capable of authenticating as a service principal using a client secret or a certificate, or as
Expand Down Expand Up @@ -50,6 +51,17 @@ def __init__(self, **kwargs: "Any") -> None:
**kwargs
)

async def __aenter__(self):
if self._credential:
await self._credential.__aenter__()
return self

async def close(self):
"""Close the credential's transport session."""

if self._credential:
await self._credential.__aexit__()

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
"""Asynchronously request an access token for `scopes`.
Expand Down
Loading

0 comments on commit 8b40aae

Please sign in to comment.