Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More accurate type annotation for scopes #12579

Merged
merged 1 commit into from
Jul 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Optional, Sequence
from typing import Any, Iterable, Optional
from azure.core.credentials import AccessToken


Expand Down Expand Up @@ -73,7 +73,7 @@ def get_token(self, *scopes, **kwargs):
return token

def _redeem_refresh_token(self, scopes, **kwargs):
# type: (Sequence[str], **Any) -> Optional[AccessToken]
# type: (Iterable[str], **Any) -> Optional[AccessToken]
for refresh_token in self._client.get_cached_refresh_tokens(scopes):
if "secret" not in refresh_token:
continue
Expand Down
6 changes: 3 additions & 3 deletions sdk/identity/azure-identity/azure/identity/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from azure.core.exceptions import ClientAuthenticationError

if TYPE_CHECKING:
from typing import Any, Optional, Sequence
from typing import Any, Iterable, Optional


class CredentialUnavailableError(ClientAuthenticationError):
Expand All @@ -18,7 +18,7 @@ class AuthenticationRequiredError(CredentialUnavailableError):
"""Interactive authentication is required to acquire a token."""

def __init__(self, scopes, message=None, error_details=None, **kwargs):
# type: (Sequence[str], Optional[str], Optional[str], **Any) -> None
# type: (Iterable[str], Optional[str], Optional[str], **Any) -> None
self._scopes = scopes
self._error_details = error_details
if not message:
Expand All @@ -27,7 +27,7 @@ def __init__(self, scopes, message=None, error_details=None, **kwargs):

@property
def scopes(self):
# type: () -> Sequence[str]
# type: () -> Iterable[str]
"""Scopes requested during the failed authentication"""
return self._scopes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, List, Optional, Sequence, Union
from typing import Any, Iterable, List, Optional, Union
from azure.core.credentials import AccessToken
from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import HttpTransport
Expand All @@ -32,7 +32,7 @@

class AadClient(AadClientBase):
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
# type: (Sequence[str], str, str, Optional[str], **Any) -> AccessToken
# type: (Iterable[str], str, str, Optional[str], **Any) -> AccessToken
request = self._get_auth_code_request(
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret
)
Expand All @@ -41,21 +41,21 @@ def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_
return self._process_response(response, now)

def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
# type: (Sequence[str], AadClientCertificate, **Any) -> AccessToken
# type: (Iterable[str], AadClientCertificate, **Any) -> AccessToken
request = self._get_client_certificate_request(scopes, certificate)
now = int(time.time())
response = self._pipeline.run(request, stream=False, **kwargs)
return self._process_response(response, now)

def obtain_token_by_client_secret(self, scopes, secret, **kwargs):
# type: (Sequence[str], str, **Any) -> AccessToken
# type: (Iterable[str], str, **Any) -> AccessToken
request = self._get_client_secret_request(scopes, secret)
now = int(time.time())
response = self._pipeline.run(request, stream=False, **kwargs)
return self._process_response(response, now)

def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
# type: (Sequence[str], str, **Any) -> AccessToken
# type: (Iterable[str], str, **Any) -> AccessToken
request = self._get_refresh_token_request(scopes, refresh_token)
now = int(time.time())
response = self._pipeline.run(request, stream=False, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Optional, Sequence, Union
from typing import Any, Iterable, List, Optional, Union
from azure.core.pipeline import AsyncPipeline, Pipeline, PipelineResponse
from azure.core.pipeline.policies import AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import AsyncHttpTransport, HttpTransport
Expand All @@ -50,7 +50,7 @@ def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs):
self._pipeline = self._build_pipeline(**kwargs)

def get_cached_access_token(self, scopes, query=None):
# type: (Sequence[str], Optional[dict]) -> Optional[AccessToken]
# type: (Iterable[str], Optional[dict]) -> Optional[AccessToken]
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes), query=query)
for token in tokens:
expires_on = int(token["expires_on"])
Expand All @@ -59,7 +59,7 @@ def get_cached_access_token(self, scopes, query=None):
return None

def get_cached_refresh_tokens(self, scopes):
# type: (Sequence[str]) -> Sequence[dict]
# type: (Iterable[str]) -> List[dict]
"""Assumes all cached refresh tokens belong to the same user"""
return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes))

Expand Down Expand Up @@ -136,7 +136,7 @@ def _process_response(self, response, request_time):
return token

def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None):
# type: (Sequence[str], str, str, Optional[str]) -> HttpRequest
# type: (Iterable[str], str, str, Optional[str]) -> HttpRequest
data = {
"client_id": self._client_id,
"code": code,
Expand All @@ -153,7 +153,7 @@ def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None)
return request

def _get_client_certificate_request(self, scopes, certificate):
# type: (Sequence[str], AadClientCertificate) -> HttpRequest
# type: (Iterable[str], AadClientCertificate) -> HttpRequest
assertion = self._get_jwt_assertion(certificate)
data = {
"client_assertion": assertion,
Expand All @@ -169,7 +169,7 @@ def _get_client_certificate_request(self, scopes, certificate):
return request

def _get_client_secret_request(self, scopes, secret):
# type: (Sequence[str], str) -> HttpRequest
# type: (Iterable[str], str) -> HttpRequest
data = {
"client_id": self._client_id,
"client_secret": secret,
Expand Down Expand Up @@ -207,7 +207,7 @@ def _get_jwt_assertion(self, certificate):
return jwt_bytes.decode("utf-8")

def _get_refresh_token_request(self, scopes, refresh_token):
# type: (Sequence[str], str) -> HttpRequest
# type: (Iterable[str], str) -> HttpRequest
data = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def authenticate(self, **kwargs):
# type: (**Any) -> AuthenticationRecord
"""Interactively authenticate a user.

:keyword Sequence[str] scopes: scopes to request during authentication, such as those provided by
:keyword Iterable[str] scopes: scopes to request during authentication, such as those provided by
:func:`AuthenticationRequiredError.scopes`. If provided, successful authentication will cache an access token
for these scopes.
:rtype: ~azure.identity.AuthenticationRecord
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Iterable, List, Mapping, Optional, Sequence
from typing import Any, Iterable, List, Mapping, Optional
from .._internal import AadClientBase
from azure.identity import AuthenticationRecord

Expand Down Expand Up @@ -203,7 +203,7 @@ def _get_account(self, username=None, tenant_id=None):
raise CredentialUnavailableError(message=message)

def _get_cached_access_token(self, scopes, account):
# type: (Sequence[str], CacheItem) -> Optional[AccessToken]
# type: (Iterable[str], CacheItem) -> Optional[AccessToken]
if "home_account_id" not in account:
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Optional, Sequence
from typing import Any, Iterable, Optional
from azure.core.credentials import AccessToken


Expand Down Expand Up @@ -88,7 +88,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":

return token

async def _redeem_refresh_token(self, scopes: "Sequence[str]", **kwargs: "Any") -> "Optional[AccessToken]":
async def _redeem_refresh_token(self, scopes: "Iterable[str]", **kwargs: "Any") -> "Optional[AccessToken]":
for refresh_token in self._client.get_cached_refresh_tokens(scopes):
if "secret" not in refresh_token:
continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, List, Optional, Sequence, Union
from typing import Any, Iterable, List, Optional, Union
from azure.core.credentials import AccessToken
from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import AsyncHttpTransport
Expand All @@ -45,7 +45,7 @@ async def close(self) -> None:

async def obtain_token_by_authorization_code(
self,
scopes: "Sequence[str]",
scopes: "Iterable[str]",
code: str,
redirect_uri: str,
client_secret: "Optional[str]" = None,
Expand All @@ -59,22 +59,22 @@ async def obtain_token_by_authorization_code(
return self._process_response(response, now)

async def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
# type: (Sequence[str], AadClientCertificate, **Any) -> AccessToken
# type: (Iterable[str], AadClientCertificate, **Any) -> AccessToken
request = self._get_client_certificate_request(scopes, certificate)
now = int(time.time())
response = await self._pipeline.run(request, stream=False, **kwargs)
return self._process_response(response, now)

async def obtain_token_by_client_secret(
self, scopes: "Sequence[str]", secret: str, **kwargs: "Any"
self, scopes: "Iterable[str]", secret: str, **kwargs: "Any"
) -> "AccessToken":
request = self._get_client_secret_request(scopes, secret)
now = int(time.time())
response = await self._pipeline.run(request, **kwargs)
return self._process_response(response, now)

async def obtain_token_by_refresh_token(
self, scopes: "Sequence[str]", refresh_token: str, **kwargs: "Any"
self, scopes: "Iterable[str]", refresh_token: str, **kwargs: "Any"
) -> "AccessToken":
request = self._get_refresh_token_request(scopes, refresh_token)
now = int(time.time())
Expand Down