Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default impl to handle token challenges #37652

Merged
merged 12 commits into from
Oct 4, 2024
4 changes: 3 additions & 1 deletion sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Release History

## 1.31.1 (Unreleased)
## 1.32.0 (Unreleased)

### Features Added

- Added a default implementation to handle token challenges in `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy`.

### Breaking Changes

### Bugs Fixed
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-core/azure/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
# regenerated.
# --------------------------------------------------------------------------

VERSION = "1.31.1"
VERSION = "1.32.0"
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# license information.
# -------------------------------------------------------------------------
import time
import base64
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast
from azure.core.credentials import (
TokenCredential,
Expand All @@ -19,6 +20,7 @@
from azure.core.rest import HttpResponse, HttpRequest
from . import HTTPPolicy, SansIOHTTPPolicy
from ...exceptions import ServiceRequestError
from ._utils import get_challenge_parameter

if TYPE_CHECKING:

Expand Down Expand Up @@ -82,13 +84,7 @@ def _need_new_token(self) -> bool:
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300

def _request_token(self, *scopes: str, **kwargs: Any) -> None:
"""Request a new token from the credential.

This will call the credential's appropriate method to get a token and store it in the policy.

:param str scopes: The type of access needed.
"""
def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
if self._enable_cae:
kwargs.setdefault("enable_cae", self._enable_cae)

Expand All @@ -99,9 +95,17 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> None:
if key in TokenRequestOptions.__annotations__: # pylint: disable=no-member
options[key] = kwargs.pop(key) # type: ignore[literal-required]

self._token = cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options)
else:
self._token = cast(TokenCredential, self._credential).get_token(*scopes, **kwargs)
return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options)
return cast(TokenCredential, self._credential).get_token(*scopes, **kwargs)

def _request_token(self, *scopes: str, **kwargs: Any) -> None:
"""Request a new token from the credential.

This will call the credential's appropriate method to get a token and store it in the policy.

:param str scopes: The type of access needed.
"""
self._token = self._get_token(*scopes, **kwargs)


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
Expand Down Expand Up @@ -191,6 +195,21 @@ def on_challenge(
:rtype: bool
"""
# pylint:disable=unused-argument
headers = response.http_response.headers
error = get_challenge_parameter(headers, "Bearer", "error")
if error == "insufficient_claims":
encoded_claims = get_challenge_parameter(headers, "Bearer", "claims")
if not encoded_claims:
return False
claims = base64.urlsafe_b64decode(encoded_claims).decode("utf-8")
pvaneck marked this conversation as resolved.
Show resolved Hide resolved
if claims:
try:
token = self._get_token(*self._scopes, claims=claims)
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token
self._update_headers(request.http_request.headers, bearer_token)
pvaneck marked this conversation as resolved.
Show resolved Hide resolved
return True
except Exception: # pylint:disable=broad-except
return False
return False

def on_response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# license information.
# -------------------------------------------------------------------------
import time
import base64
from typing import Any, Awaitable, Optional, cast, TypeVar, Union

from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
Expand All @@ -23,6 +24,7 @@
)
from azure.core.rest import AsyncHttpResponse, HttpRequest
from azure.core.utils._utils import get_running_async_lock
from ._utils import get_challenge_parameter

from .._tools_async import await_result

Expand Down Expand Up @@ -138,6 +140,21 @@ async def on_challenge(
:rtype: bool
"""
# pylint:disable=unused-argument
headers = response.http_response.headers
error = get_challenge_parameter(headers, "Bearer", "error")
if error == "insufficient_claims":
encoded_claims = get_challenge_parameter(headers, "Bearer", "claims")
if not encoded_claims:
return False
claims = base64.urlsafe_b64decode(encoded_claims).decode("utf-8")
if claims:
try:
token = await self._get_token(*self._scopes, claims=claims)
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token
request.http_request.headers["Authorization"] = "Bearer " + bearer_token
return True
except Exception: # pylint:disable=broad-except
return False
return False

def on_response(
Expand Down Expand Up @@ -169,13 +186,7 @@ def _need_new_token(self) -> bool:
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300

async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
"""Request a new token from the credential.

This will call the credential's appropriate method to get a token and store it in the policy.

:param str scopes: The type of access needed.
"""
async def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
if self._enable_cae:
kwargs.setdefault("enable_cae", self._enable_cae)

Expand All @@ -186,14 +197,22 @@ async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
if key in TokenRequestOptions.__annotations__: # pylint: disable=no-member
options[key] = kwargs.pop(key) # type: ignore[literal-required]

self._token = await await_result(
return await await_result(
cast(AsyncSupportsTokenInfo, self._credential).get_token_info,
*scopes,
options=options,
)
else:
self._token = await await_result(
cast(AsyncTokenCredential, self._credential).get_token,
*scopes,
**kwargs,
)
return await await_result(
cast(AsyncTokenCredential, self._credential).get_token,
*scopes,
**kwargs,
)

async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
"""Request a new token from the credential.

This will call the credential's appropriate method to get a token and store it in the policy.

:param str scopes: The type of access needed.
"""
self._token = await self._get_token(*scopes, **kwargs)
100 changes: 99 additions & 1 deletion sdk/core/azure-core/azure/core/pipeline/policies/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# --------------------------------------------------------------------------
import datetime
import email.utils
from typing import Optional, cast, Union
from typing import Optional, cast, Union, Tuple
from urllib.parse import urlparse

from azure.core.pipeline.transport import (
Expand Down Expand Up @@ -102,3 +102,101 @@ def get_domain(url: str) -> str:
:return: The domain of the url.
"""
return str(urlparse(url).netloc).lower()


def get_challenge_parameter(headers, challenge_scheme: str, challenge_parameter: str) -> Optional[str]:
"""
Parses the specified parameter from a challenge header found in the response.

:param dict[str, str] headers: The response headers to parse.
:param str challenge_scheme: The challenge scheme containing the challenge parameter, e.g., "Bearer".
:param str challenge_parameter: The parameter key name to search for.
:return: The value of the parameter name if found.
:rtype: str or None
"""
header_value = headers.get("WWW-Authenticate")
if not header_value:
return None

scheme = challenge_scheme
parameter = challenge_parameter
header_span = header_value

# Iterate through each challenge value.
while get_next_challenge(header_span):
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
challenge = get_next_challenge(header_span)
if not challenge:
break
challenge_key, header_span = challenge
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
# Enumerate each key-value parameter until we find the parameter key on the specified scheme challenge.
while get_next_parameter(header_span):
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
parameters = get_next_parameter(header_span)
if not parameters:
break
key, value, header_span = parameters
if challenge_key.lower() == scheme.lower() and key.lower() == parameter.lower():
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
return value

return None


def get_next_challenge(header_value: str) -> Optional[Tuple[str, str]]:
"""
Iterates through the challenge schemes present in a challenge header.

:param str header_value: The header value which will be sliced to remove the first parsed challenge key.
:return: The parsed challenge scheme and the remaining header value.
:rtype: tuple[str, str] or None
"""
header_value = header_value.lstrip(" ")
end_of_challenge_key = header_value.find(" ")

if end_of_challenge_key < 0:
return None

challenge_key = header_value[:end_of_challenge_key]
header_value = header_value[end_of_challenge_key + 1 :]

return challenge_key, header_value


def get_next_parameter(header_value: str, separator: str = "=") -> Optional[Tuple[str, str, str]]:
"""
Iterates through a challenge header value to extract key-value parameters.

:param str header_value: The header value after being parsed by get_next_challenge.
:param str separator: The challenge parameter key-value pair separator, default is '='.
:return: The next available challenge parameter as a tuple (param_key, param_value, remaining header_value).
:rtype: tuple[str, str, str] or None
"""
space_or_comma = " ,"
header_value = header_value.lstrip(space_or_comma)

next_space = header_value.find(" ")
next_separator = header_value.find(separator)

if next_space < next_separator and next_space != -1:
return None

if next_separator < 0:
return None

param_key = header_value[:next_separator].strip()
header_value = header_value[next_separator + 1 :]

quote_index = header_value.find('"')

if quote_index >= 0:
header_value = header_value[quote_index + 1 :]
param_value = header_value[: header_value.find('"')]
else:
trailing_delimiter_index = header_value.find(" ")
if trailing_delimiter_index >= 0:
param_value = header_value[:trailing_delimiter_index]
else:
param_value = header_value

if header_value != param_value:
header_value = header_value[len(param_value) + 1 :]

return param_key, param_value, header_value
26 changes: 25 additions & 1 deletion sdk/core/azure-core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from azure.core.utils import case_insensitive_dict
from azure.core.utils._utils import get_running_async_lock
from azure.core.pipeline.policies._utils import parse_retry_after
from azure.core.pipeline.policies._utils import parse_retry_after, get_challenge_parameter


@pytest.fixture()
Expand Down Expand Up @@ -146,3 +146,27 @@ def test_parse_retry_after():
assert ret == 0
ret = parse_retry_after("0.9")
assert ret == 0.9


def test_get_challenge_parameter():
headers = {
"WWW-Authenticate": 'Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"'
}
assert (
get_challenge_parameter(headers, "Bearer", "authorization_uri") == "https://login.microsoftonline.com/tenant-id"
)
assert get_challenge_parameter(headers, "Bearer", "resource") == "https://vault.azure.net"
assert get_challenge_parameter(headers, "Bearer", "foo") is None

headers = {
"WWW-Authenticate": 'Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="'
}
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
assert (
get_challenge_parameter(headers, "Bearer", "authorization_uri")
== "https://login.microsoftonline.com/common/oauth2/authorize"
)
assert get_challenge_parameter(headers, "Bearer", "error") == "insufficient_claims"
assert (
get_challenge_parameter(headers, "Bearer", "claims")
== "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="
)
pvaneck marked this conversation as resolved.
Show resolved Hide resolved
Loading