Skip to content

Commit

Permalink
Sequence -> Iterable for scopes (#12579)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Jul 17, 2020
1 parent 82fea83 commit a5a1db1
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Optional, Sequence
from typing import Any, Iterable, Optional
from azure.core.credentials import AccessToken


Expand Down Expand Up @@ -73,7 +73,7 @@ def get_token(self, *scopes, **kwargs):
return token

def _redeem_refresh_token(self, scopes, **kwargs):
# type: (Sequence[str], **Any) -> Optional[AccessToken]
# type: (Iterable[str], **Any) -> Optional[AccessToken]
for refresh_token in self._client.get_cached_refresh_tokens(scopes):
if "secret" not in refresh_token:
continue
Expand Down
6 changes: 3 additions & 3 deletions sdk/identity/azure-identity/azure/identity/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from azure.core.exceptions import ClientAuthenticationError

if TYPE_CHECKING:
from typing import Any, Optional, Sequence
from typing import Any, Iterable, Optional


class CredentialUnavailableError(ClientAuthenticationError):
Expand All @@ -18,7 +18,7 @@ class AuthenticationRequiredError(CredentialUnavailableError):
"""Interactive authentication is required to acquire a token."""

def __init__(self, scopes, message=None, error_details=None, **kwargs):
# type: (Sequence[str], Optional[str], Optional[str], **Any) -> None
# type: (Iterable[str], Optional[str], Optional[str], **Any) -> None
self._scopes = scopes
self._error_details = error_details
if not message:
Expand All @@ -27,7 +27,7 @@ def __init__(self, scopes, message=None, error_details=None, **kwargs):

@property
def scopes(self):
# type: () -> Sequence[str]
# type: () -> Iterable[str]
"""Scopes requested during the failed authentication"""
return self._scopes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, List, Optional, Sequence, Union
from typing import Any, Iterable, List, Optional, Union
from azure.core.credentials import AccessToken
from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import HttpTransport
Expand All @@ -32,7 +32,7 @@

class AadClient(AadClientBase):
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
# type: (Sequence[str], str, str, Optional[str], **Any) -> AccessToken
# type: (Iterable[str], str, str, Optional[str], **Any) -> AccessToken
request = self._get_auth_code_request(
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret
)
Expand All @@ -41,21 +41,21 @@ def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_
return self._process_response(response, now)

def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
# type: (Sequence[str], AadClientCertificate, **Any) -> AccessToken
# type: (Iterable[str], AadClientCertificate, **Any) -> AccessToken
request = self._get_client_certificate_request(scopes, certificate)
now = int(time.time())
response = self._pipeline.run(request, stream=False, **kwargs)
return self._process_response(response, now)

def obtain_token_by_client_secret(self, scopes, secret, **kwargs):
# type: (Sequence[str], str, **Any) -> AccessToken
# type: (Iterable[str], str, **Any) -> AccessToken
request = self._get_client_secret_request(scopes, secret)
now = int(time.time())
response = self._pipeline.run(request, stream=False, **kwargs)
return self._process_response(response, now)

def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
# type: (Sequence[str], str, **Any) -> AccessToken
# type: (Iterable[str], str, **Any) -> AccessToken
request = self._get_refresh_token_request(scopes, refresh_token)
now = int(time.time())
response = self._pipeline.run(request, stream=False, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Optional, Sequence, Union
from typing import Any, Iterable, List, Optional, Union
from azure.core.pipeline import AsyncPipeline, Pipeline, PipelineResponse
from azure.core.pipeline.policies import AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import AsyncHttpTransport, HttpTransport
Expand All @@ -50,7 +50,7 @@ def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs):
self._pipeline = self._build_pipeline(**kwargs)

def get_cached_access_token(self, scopes, query=None):
# type: (Sequence[str], Optional[dict]) -> Optional[AccessToken]
# type: (Iterable[str], Optional[dict]) -> Optional[AccessToken]
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes), query=query)
for token in tokens:
expires_on = int(token["expires_on"])
Expand All @@ -59,7 +59,7 @@ def get_cached_access_token(self, scopes, query=None):
return None

def get_cached_refresh_tokens(self, scopes):
# type: (Sequence[str]) -> Sequence[dict]
# type: (Iterable[str]) -> List[dict]
"""Assumes all cached refresh tokens belong to the same user"""
return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes))

Expand Down Expand Up @@ -136,7 +136,7 @@ def _process_response(self, response, request_time):
return token

def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None):
# type: (Sequence[str], str, str, Optional[str]) -> HttpRequest
# type: (Iterable[str], str, str, Optional[str]) -> HttpRequest
data = {
"client_id": self._client_id,
"code": code,
Expand All @@ -153,7 +153,7 @@ def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None)
return request

def _get_client_certificate_request(self, scopes, certificate):
# type: (Sequence[str], AadClientCertificate) -> HttpRequest
# type: (Iterable[str], AadClientCertificate) -> HttpRequest
assertion = self._get_jwt_assertion(certificate)
data = {
"client_assertion": assertion,
Expand All @@ -169,7 +169,7 @@ def _get_client_certificate_request(self, scopes, certificate):
return request

def _get_client_secret_request(self, scopes, secret):
# type: (Sequence[str], str) -> HttpRequest
# type: (Iterable[str], str) -> HttpRequest
data = {
"client_id": self._client_id,
"client_secret": secret,
Expand Down Expand Up @@ -207,7 +207,7 @@ def _get_jwt_assertion(self, certificate):
return jwt_bytes.decode("utf-8")

def _get_refresh_token_request(self, scopes, refresh_token):
# type: (Sequence[str], str) -> HttpRequest
# type: (Iterable[str], str) -> HttpRequest
data = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def authenticate(self, **kwargs):
# type: (**Any) -> AuthenticationRecord
"""Interactively authenticate a user.
:keyword Sequence[str] scopes: scopes to request during authentication, such as those provided by
:keyword Iterable[str] scopes: scopes to request during authentication, such as those provided by
:func:`AuthenticationRequiredError.scopes`. If provided, successful authentication will cache an access token
for these scopes.
:rtype: ~azure.identity.AuthenticationRecord
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Iterable, List, Mapping, Optional, Sequence
from typing import Any, Iterable, List, Mapping, Optional
from .._internal import AadClientBase
from azure.identity import AuthenticationRecord

Expand Down Expand Up @@ -203,7 +203,7 @@ def _get_account(self, username=None, tenant_id=None):
raise CredentialUnavailableError(message=message)

def _get_cached_access_token(self, scopes, account):
# type: (Sequence[str], CacheItem) -> Optional[AccessToken]
# type: (Iterable[str], CacheItem) -> Optional[AccessToken]
if "home_account_id" not in account:
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Optional, Sequence
from typing import Any, Iterable, Optional
from azure.core.credentials import AccessToken


Expand Down Expand Up @@ -88,7 +88,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":

return token

async def _redeem_refresh_token(self, scopes: "Sequence[str]", **kwargs: "Any") -> "Optional[AccessToken]":
async def _redeem_refresh_token(self, scopes: "Iterable[str]", **kwargs: "Any") -> "Optional[AccessToken]":
for refresh_token in self._client.get_cached_refresh_tokens(scopes):
if "secret" not in refresh_token:
continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, List, Optional, Sequence, Union
from typing import Any, Iterable, List, Optional, Union
from azure.core.credentials import AccessToken
from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import AsyncHttpTransport
Expand All @@ -45,7 +45,7 @@ async def close(self) -> None:

async def obtain_token_by_authorization_code(
self,
scopes: "Sequence[str]",
scopes: "Iterable[str]",
code: str,
redirect_uri: str,
client_secret: "Optional[str]" = None,
Expand All @@ -59,22 +59,22 @@ async def obtain_token_by_authorization_code(
return self._process_response(response, now)

async def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
# type: (Sequence[str], AadClientCertificate, **Any) -> AccessToken
# type: (Iterable[str], AadClientCertificate, **Any) -> AccessToken
request = self._get_client_certificate_request(scopes, certificate)
now = int(time.time())
response = await self._pipeline.run(request, stream=False, **kwargs)
return self._process_response(response, now)

async def obtain_token_by_client_secret(
self, scopes: "Sequence[str]", secret: str, **kwargs: "Any"
self, scopes: "Iterable[str]", secret: str, **kwargs: "Any"
) -> "AccessToken":
request = self._get_client_secret_request(scopes, secret)
now = int(time.time())
response = await self._pipeline.run(request, **kwargs)
return self._process_response(response, now)

async def obtain_token_by_refresh_token(
self, scopes: "Sequence[str]", refresh_token: str, **kwargs: "Any"
self, scopes: "Iterable[str]", refresh_token: str, **kwargs: "Any"
) -> "AccessToken":
request = self._get_refresh_token_request(scopes, refresh_token)
now = int(time.time())
Expand Down

0 comments on commit a5a1db1

Please sign in to comment.