Skip to content

Commit

Permalink
auth code credential is an async context manager too
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell committed Jan 13, 2020
1 parent 8aa045d commit 4f33abf
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TYPE_CHECKING

from azure.core.exceptions import ClientAuthenticationError
from .base import AsyncCredentialBase
from .._internal import AadClient

if TYPE_CHECKING:
Expand All @@ -14,7 +15,7 @@
from azure.core.credentials import AccessToken


class AuthorizationCodeCredential(object):
class AuthorizationCodeCredential(AsyncCredentialBase):
"""Authenticates by redeeming an authorization code previously obtained from Azure Active Directory.
See https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow for more information
Expand All @@ -31,13 +32,19 @@ class AuthorizationCodeCredential(object):
:keyword str client_secret: One of the application's client secrets. Required only for web apps and web APIs.
"""

async def __aenter__(self):
if self._client:
await self._client.__aenter__()
return self

async def close(self):
"""Close the credential's transport session."""

if self._client:
await self._client.__aexit__()

def __init__(
self,
tenant_id: str,
client_id: str,
authorization_code: str,
redirect_uri: str,
**kwargs: "Any"
self, tenant_id: str, client_id: str, authorization_code: str, redirect_uri: str, **kwargs: "Any"
) -> None:
self._authorization_code = authorization_code # type: Optional[str]
self._client_id = client_id
Expand Down
28 changes: 27 additions & 1 deletion sdk/identity/azure-identity/tests/test_auth_code_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pytest

from helpers import build_aad_response, mock_response, Request
from helpers_async import async_validating_transport, wrap_in_future
from helpers_async import async_validating_transport, AsyncMockTransport, wrap_in_future


@pytest.mark.asyncio
Expand All @@ -32,6 +32,32 @@ async def send(*_, **__):
assert policy.on_request.called


@pytest.mark.asyncio
async def test_close():
transport = AsyncMockTransport()
credential = AuthorizationCodeCredential(
"tenant-id", "client-id", "auth-code", "http://localhost", transport=transport
)

await credential.close()

assert transport.__aexit__.call_count == 1


@pytest.mark.asyncio
async def test_context_manager():
transport = AsyncMockTransport()
credential = AuthorizationCodeCredential(
"tenant-id", "client-id", "auth-code", "http://localhost", transport=transport
)

async with credential:
assert transport.__aenter__.call_count == 1

assert transport.__aenter__.call_count == 1
assert transport.__aexit__.call_count == 1


@pytest.mark.asyncio
async def test_user_agent():
transport = async_validating_transport(
Expand Down

0 comments on commit 4f33abf

Please sign in to comment.