Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose methods for closing async credential transport sessions #9090

Merged
merged 19 commits into from
Jan 13, 2020
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions sdk/core/azure-core/azure/core/credentials_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any
from typing_extensions import Protocol
from .credentials import AccessToken

class AsyncTokenCredential(Protocol):
async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
pass

async def close(self) -> None:
pass

async def __aenter__(self):
pass

async def __aexit__(self, exc_type, exc_value, traceback) -> None:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,10 @@ class _BearerTokenCredentialPolicyBase(object):
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (TokenCredential, *str, Mapping[str, Any]) -> None
def __init__(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> None
super(_BearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._token = None # type: Optional[AccessToken]

@staticmethod
Expand Down Expand Up @@ -69,6 +68,11 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPo
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""

def __init__(self, credential, *scopes, **kwargs):
# type: (TokenCredential, *str, **Any) -> None
self._credential = credential
super(BearerTokenCredentialPolicy, self).__init__(*scopes, **kwargs)

def on_request(self, request):
# type: (PipelineRequest) -> None
"""Adds a bearer token Authorization header to request and sends request to next policy.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,33 @@
# license information.
# -------------------------------------------------------------------------
import threading
from typing import TYPE_CHECKING

from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.core.pipeline.policies._authentication import _BearerTokenCredentialPolicyBase

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.pipeline import PipelineRequest


class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
# pylint:disable=too-few-public-methods
"""Adds a bearer token Authorization header to requests.

:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:type credential: ~azure.core.credentials_async.AsyncTokenCredential
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential, *scopes, **kwargs):
super().__init__(credential, *scopes, **kwargs)
def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: "Any") -> None:
self._credential = credential
self._lock = threading.Lock()
super().__init__(*scopes, **kwargs)

async def on_request(self, request: PipelineRequest):
async def on_request(self, request: "PipelineRequest"):
"""Adds a bearer token Authorization header to request and sends request to next policy.

:param request: The pipeline request object to be modified.
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

- All credential pipelines include `ProxyPolicy`
([#8945](https://github.com/Azure/azure-sdk-for-python/pull/8945))
- Async credentials are async context managers and have an async `close` method
([#9090](https://github.com/Azure/azure-sdk-for-python/pull/9090))


## 1.1.0 (2019-11-27)
Expand Down
18 changes: 18 additions & 0 deletions sdk/identity/azure-identity/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,24 @@ async transport, such as [aiohttp](https://pypi.org/project/aiohttp/). See
[azure-core documentation](../../core/azure-core/README.md#transport)
for more information.

Async credentials should be closed when they're no longer needed. Each async
credential is an async context manager and defines an async `close` method. For
example:

```py
from azure.identity.aio import DefaultAzureCredential

# call close when the credential is no longer needed
credential = DefaultAzureCredential()
...
await credential.close()

# alternatively, use the credential as an async context manager
credential = DefaultAzureCredential()
async with credential:
...
```

This example demonstrates authenticating the asynchronous `SecretClient` from
[azure-keyvault-secrets][azure_keyvault_secrets] with an asynchronous
credential.
Expand Down
10 changes: 10 additions & 0 deletions sdk/identity/azure-identity/azure/identity/aio/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ def __init__(
self._pipeline = AsyncPipeline(transport=transport, policies=policies)
super().__init__(**kwargs)

async def __aenter__(self):
await self._pipeline.__aenter__()
return self

async def __aexit__(self, *args):
await self.close()

async def close(self) -> None:
await self._pipeline.__aexit__()

async def request_token(
self,
scopes: "Iterable[str]",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import abc


class AsyncCredentialBase(abc.ABC):
@abc.abstractmethod
async def close(self):
pass

async def __aenter__(self):
lmazuel marked this conversation as resolved.
Show resolved Hide resolved
return self

async def __aexit__(self, *args):
await self.close()

@abc.abstractmethod
async def get_token(self, *scopes, **kwargs):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import asyncio
from typing import TYPE_CHECKING

from azure.core.exceptions import ClientAuthenticationError
from ... import ChainedTokenCredential as SyncChainedTokenCredential
from .base import AsyncCredentialBase

if TYPE_CHECKING:
from typing import Any
from azure.core.credentials import AccessToken


class ChainedTokenCredential(SyncChainedTokenCredential):
class ChainedTokenCredential(SyncChainedTokenCredential, AsyncCredentialBase):
"""A sequence of credentials that is itself a credential.

Its :func:`get_token` method calls ``get_token`` on each credential in the sequence, in order, returning the first
Expand All @@ -22,6 +24,11 @@ class ChainedTokenCredential(SyncChainedTokenCredential):
:type credentials: :class:`azure.core.credentials.TokenCredential`
"""

async def close(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a concern here: do we expect customer to "async enter" all the credentials in the chain, or should we have a aenter here that loop thourgh all of them and enter them?
Can I see a sample of usage of this one?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what I want to enable:

credential = DefaultAzureCredential()
client = FooServiceClient(credential)
# ... time passes, many useful service requests are authorized ...
credential.close()

I think close is the important API. I don't expect anyone to "enter" or "open" a credential. I have aenter doing nothing here because the credential doesn't know which members of its chain will send requests, and transports will open sessions as needed. (At least, our current transport implementations will. Perhaps we should make explicit who's expected to open a transport.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you implement __aenter__/__aexit__ (e.g. you are an async context manager), then you are strongly signalling to users that using async with is general goodness, but you can do the closing yourself if you so see fit.

If we don't want to give an example of intended use with async with, then why is it an async context manager?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be consistent with everything else in the SDK that wraps async transport by exposing __aenter__, __aexit__, and close. This PR does add an async with example to the README, and it's okay to use credentials that way. Every other credential's __aenter__ invokes its transport's __aenter__.

The awkwardness for ChainedTokenCredential.__aenter__ is that if it opens sessions for N credentials, N-1 of them may never be used, at some cost dependent on the HTTP client's implementation. These sessions will all be closed by __aexit__, but I thought it unnecessary to open them given that our async transports will do so as needed.

"""Close the transport sessions of all credentials in the chain."""

await asyncio.gather(*(credential.close() for credential in self.credentials))

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
"""Asynchronously request a token from each credential, in order, returning the first token received.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# ------------------------------------
from typing import TYPE_CHECKING

from .base import AsyncCredentialBase
from .._authn_client import AsyncAuthnClient
from ..._base import ClientSecretCredentialBase, CertificateCredentialBase

Expand All @@ -12,7 +13,7 @@
from azure.core.credentials import AccessToken


class ClientSecretCredential(ClientSecretCredentialBase):
class ClientSecretCredential(ClientSecretCredentialBase, AsyncCredentialBase):
"""Authenticates as a service principal using a client ID and client secret.

:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
Expand All @@ -28,6 +29,15 @@ def __init__(self, tenant_id: str, client_id: str, client_secret: str, **kwargs:
super(ClientSecretCredential, self).__init__(tenant_id, client_id, client_secret, **kwargs)
self._client = AsyncAuthnClient(tenant=tenant_id, **kwargs)

async def __aenter__(self):
await self._client.__aenter__()
return self

async def close(self):
"""Close the credential's transport session."""

await self._client.__aexit__()

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
"""Asynchronously request an access token for `scopes`.

Expand All @@ -44,7 +54,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
return token # type: ignore


class CertificateCredential(CertificateCredentialBase):
class CertificateCredential(CertificateCredentialBase, AsyncCredentialBase):
"""Authenticates as a service principal using a certificate.

:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
Expand All @@ -57,6 +67,15 @@ class CertificateCredential(CertificateCredentialBase):
defines authorities for other clouds.
"""

async def __aenter__(self):
await self._client.__aenter__()
return self

async def close(self):
"""Close the credential's transport session."""

await self._client.__aexit__()

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
"""Asynchronously request an access token for `scopes`.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from azure.core.exceptions import ClientAuthenticationError
from ..._constants import EnvironmentVariables
from .client_credential import CertificateCredential, ClientSecretCredential
from .base import AsyncCredentialBase

if TYPE_CHECKING:
from typing import Any, Optional, Union
from azure.core.credentials import AccessToken


class EnvironmentCredential:
class EnvironmentCredential(AsyncCredentialBase):
"""A credential configured by environment variables.

This credential is capable of authenticating as a service principal using a client secret or a certificate, or as
Expand Down Expand Up @@ -50,6 +51,17 @@ def __init__(self, **kwargs: "Any") -> None:
**kwargs
)

async def __aenter__(self):
if self._credential:
await self._credential.__aenter__()
return self

async def close(self):
"""Close the credential's transport session."""

if self._credential:
await self._credential.__aexit__()

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
"""Asynchronously request an access token for `scopes`.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import abc
import os
from typing import TYPE_CHECKING

Expand All @@ -10,6 +11,7 @@
from azure.core.pipeline.policies import AsyncRetryPolicy

from azure.identity._credentials.managed_identity import _ManagedIdentityBase
from .base import AsyncCredentialBase
from .._authn_client import AsyncAuthnClient
from ..._constants import Endpoints, EnvironmentVariables

Expand Down Expand Up @@ -37,6 +39,15 @@ def __new__(cls, *args, **kwargs):
def __init__(self, **kwargs: "Any") -> None:
pass

async def __aenter__(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is getting a bit more odd the more methods you add to the class. I know I asked before, but I can't remember the answer - why is this a class with a __new__ and not just a function that returns the correct credential type?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To hide the implementations while providing the right surface for documentation and code completion. It could instead hold an instance of the appropriate credential and wrap its methods... #9302

pass

async def __aexit__(self, *args):
pass

async def close(self):
"""Close the credential's transport session."""

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
"""Asynchronously request an access token for `scopes`.

Expand All @@ -49,10 +60,23 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
return AccessToken()


class _AsyncManagedIdentityBase(_ManagedIdentityBase):
class _AsyncManagedIdentityBase(_ManagedIdentityBase, AsyncCredentialBase):
def __init__(self, endpoint: str, **kwargs: "Any") -> None:
super().__init__(endpoint=endpoint, client_cls=AsyncAuthnClient, **kwargs)

async def __aenter__(self):
await self._client.__aenter__()
return self

async def close(self):
"""Close the credential's transport session."""

await self._client.__aexit__()

@abc.abstractmethod
async def get_token(self, *scopes, **kwargs):
pass

@staticmethod
def _create_config(**kwargs: "Any") -> "Configuration":
"""Build a default configuration for the credential's HTTP pipeline."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,31 @@
from ..._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase
from .._internal.aad_client import AadClient
from .._internal.exception_wrapper import wrap_exceptions
from .base import AsyncCredentialBase

if TYPE_CHECKING:
from typing import Any
from azure.core.credentials import AccessToken
from ..._internal.aad_client import AadClientBase


class SharedTokenCacheCredential(SharedTokenCacheBase):
class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncCredentialBase):
"""Authenticates using tokens in the local cache shared between Microsoft applications.

:param str username:
Username (typically an email address) of the user to authenticate as. This is required because the local cache
may contain tokens for multiple identities.
"""

async def __aenter__(self):
await self._client.__aenter__()
return self

xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
async def close(self):
"""Close the credential's transport session."""

await self._client.__aexit__()

@wrap_exceptions
async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
"""Get an access token for `scopes` from the shared cache.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,19 @@


class AadClient(AadClientBase):
# pylint:disable=arguments-differ
async def __aenter__(self):
await self._client.session.__aenter__()
return self

async def __aexit__(self, *args):
await self.close()

async def close(self) -> None:
"""Close the client's transport session."""

await self._client.session.__aexit__()

# pylint:disable=arguments-differ
def obtain_token_by_authorization_code(
self, *args: "Any", loop: "asyncio.AbstractEventLoop" = None, **kwargs: "Any"
) -> "AccessToken":
Expand Down
Loading