Skip to content

Commit

Permalink
Issue #191: add support for OIDC device code flow with PKCE
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Mar 18, 2021
1 parent 80fda18 commit 9f5596e
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 37 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add initial/experimental support for OIDC device code flow with PKCE (alternative for client secret) ([#191](https://github.com/Open-EO/openeo-python-client/issues/191) / EP-3700)

### Changed

### Removed
Expand Down
76 changes: 46 additions & 30 deletions openeo/rest/auth/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,8 @@ def drain_queue(queue: Queue, initial_timeout: float = 10, item_minimum: int = 1
def random_string(length=32, characters: str = None):
"""
Build a random string from given characters (alphanumeric by default)
TODO: move this to a utils module?
"""
# TODO: move this to a utils module?
characters = characters or (string.ascii_letters + string.digits)
return "".join(random.choice(characters) for _ in range(length))

Expand Down Expand Up @@ -339,6 +338,28 @@ def _extract_token(data: dict, key: str, expected_nonce: str = None, allow_absen
return token


class PkceCode:
"""
Simple container for PKCE code verifier and code challenge.
PKCE, pronounced "pixy", is short for "Proof Key for Code Exchange".
Also see https://tools.ietf.org/html/rfc7636
"""
__slots__ = ["code_verifier", "code_challenge", "code_challenge_method"]

def __init__(self):
self.code_verifier = random_string(64)
# Only SHA256 is supported for now.
self.code_challenge_method = "S256"
self.code_challenge = PkceCode.sha256_hash(self.code_verifier)

@staticmethod
def sha256_hash(code: str) -> str:
"""Apply SHA256 hash to code verifier to get code challenge"""
data = hashlib.sha256(code.encode('ascii')).digest()
return base64.urlsafe_b64encode(data).decode('ascii').replace('=', '')


AuthCodeResult = namedtuple("AuthCodeResult", ["auth_code", "nonce", "code_verifier", "redirect_uri"])


Expand Down Expand Up @@ -376,28 +397,14 @@ def __init__(
self._authentication_timeout = timeout or self.TIMEOUT_DEFAULT
self._server_address = server_address

@staticmethod
def hash_code_verifier(code: str) -> str:
"""Hash code verifier to code challenge"""
return base64.urlsafe_b64encode(
hashlib.sha256(code.encode('ascii')).digest()
).decode('ascii').replace('=', '')

@staticmethod
def get_pkce_codes() -> Tuple[str, str]:
"""Build random PKCE code verifier and challenge"""
code_verifier = random_string(64)
code_challenge = OidcAuthCodePkceAuthenticator.hash_code_verifier(code_verifier)
return code_verifier, code_challenge

def _get_auth_code(self, request_refresh_token: bool = False) -> AuthCodeResult:
"""
Do OAuth authentication request and catch redirect to extract authentication code
:return:
"""
state = random_string(32)
nonce = random_string(21)
code_verifier, code_challenge = self.get_pkce_codes()
pkce = PkceCode()

# Set up HTTP server (in separate thread) to catch OAuth redirect URL
callback_queue = Queue()
Expand Down Expand Up @@ -426,8 +433,8 @@ def _get_auth_code(self, request_refresh_token: bool = False) -> AuthCodeResult:
"redirect_uri": redirect_uri,
"state": state,
"nonce": nonce,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
"code_challenge": pkce.code_challenge,
"code_challenge_method": pkce.code_challenge_method,
})
)
log.info("Sending user to auth URL {u!r}".format(u=auth_url))
Expand Down Expand Up @@ -468,7 +475,7 @@ def _get_auth_code(self, request_refresh_token: bool = False) -> AuthCodeResult:
auth_code = redirect_params["code"][0]

return AuthCodeResult(
auth_code=auth_code, nonce=nonce, code_verifier=code_verifier, redirect_uri=redirect_uri
auth_code=auth_code, nonce=nonce, code_verifier=pkce.code_verifier, redirect_uri=redirect_uri
)

def get_tokens(self, request_refresh_token: bool = False) -> AccessTokenResult:
Expand Down Expand Up @@ -560,23 +567,29 @@ class OidcDeviceAuthenticator(OidcAuthenticator):

grant_type = "urn:ietf:params:oauth:grant-type:device_code"

def __init__(self, client_info: OidcClientInfo, display: Callable[[str], None] = print, device_code_url: str = None,
max_poll_time=5 * 60):
def __init__(
self, client_info: OidcClientInfo, display: Callable[[str], None] = print, device_code_url: str = None,
max_poll_time=5 * 60, use_pkce: bool = False
):
super().__init__(client_info=client_info)
self._display = display
# Allow to specify/override device code URL for cases when it is not available in OIDC discovery doc.
self._device_code_url = device_code_url or self._provider_config["device_authorization_endpoint"]
self._max_poll_time = max_poll_time
# TODO: automatically use PKCE if there is no client secret?
# TODO: detect if OIDC provider supports device flow + PKCE? E.g. get this from `OidcProviderInfo` (also see https://github.com/Open-EO/openeo-api/pull/366)?
self._pkce = PkceCode() if use_pkce else None

def _get_verification_info(self, request_refresh_token: bool = False) -> VerificationInfo:
"""Get verification URL and user code"""
resp = requests.post(
url=self._device_code_url,
data={
"client_id": self.client_id,
"scope": self._client_info.provider.get_scopes_string(request_refresh_token=request_refresh_token)
}
)
post_data = {
"client_id": self.client_id,
"scope": self._client_info.provider.get_scopes_string(request_refresh_token=request_refresh_token)
}
if self._pkce:
post_data["code_challenge"] = self._pkce.code_challenge,
post_data["code_challenge_method"] = self._pkce.code_challenge_method
resp = requests.post(url=self._device_code_url, data=post_data)
if resp.status_code != 200:
raise OidcException("Failed to get verification URL and user code from {u!r}: {s} {r!r} {t!r}".format(
s=resp.status_code, r=resp.reason, u=resp.url, t=resp.text
Expand Down Expand Up @@ -605,10 +618,13 @@ def get_tokens(self, request_refresh_token: bool = False) -> AccessTokenResult:
token_endpoint = self._provider_config['token_endpoint']
post_data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"device_code": verification_info.device_code,
"grant_type": self.grant_type
}
if self._pkce:
post_data["code_verifier"] = self._pkce.code_verifier
else:
post_data["client_secret"] = self.client_secret
poll_interval = verification_info.interval
while elapsed() <= self._max_poll_time:
time.sleep(poll_interval)
Expand Down
11 changes: 8 additions & 3 deletions openeo/rest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,17 +468,22 @@ def authenticate_oidc_refresh_token(
return self._authenticate_oidc(authenticator, provider_id=provider_id)

def authenticate_oidc_device(
self, client_id: str=None, client_secret: str=None, provider_id: str = None,
store_refresh_token=False,
self, client_id: str = None, client_secret: str = None, provider_id: str = None,
store_refresh_token=False, use_pkce: bool = False,
**kwargs
) -> 'Connection':
"""
Authenticate with OAuth Device Authorization grant/flow
:param use_pkce: Use PKCE instead of client secret.
Note that this features is not widely supported among OIDC providers.
.. versionchanged:: 0.5.1 Add :py:obj:`use_pkce` argument
"""
provider_id, client_info = self._get_oidc_provider_and_client_info(
provider_id=provider_id, client_id=client_id, client_secret=client_secret
)
authenticator = OidcDeviceAuthenticator(client_info=client_info, **kwargs)
authenticator = OidcDeviceAuthenticator(client_info=client_info, use_pkce=use_pkce, **kwargs)
return self._authenticate_oidc(authenticator, provider_id=provider_id, store_refresh_token=store_refresh_token)

def describe_account(self) -> str:
Expand Down
54 changes: 50 additions & 4 deletions tests/rest/auth/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import openeo.rest.auth.oidc
from openeo.rest.auth.oidc import QueuingRequestHandler, drain_queue, HttpServerThread, OidcAuthCodePkceAuthenticator, \
OidcClientCredentialsAuthenticator, OidcResourceOwnerPasswordAuthenticator, OidcClientInfo, OidcProviderInfo, \
OidcDeviceAuthenticator, random_string, OidcRefreshTokenAuthenticator
OidcDeviceAuthenticator, random_string, OidcRefreshTokenAuthenticator, PkceCode
from openeo.util import dict_no_none


Expand Down Expand Up @@ -211,7 +211,7 @@ def token_callback_authorization_code(self, request: requests_mock.request._Requ
params = self._get_query_params(query=request.text)
assert params["client_id"] == self.expected_client_id
assert params["grant_type"] == "authorization_code"
assert self.state["code_challenge"] == OidcAuthCodePkceAuthenticator.hash_code_verifier(params["code_verifier"])
assert self.state["code_challenge"] == PkceCode.sha256_hash(params["code_verifier"])
assert params["code"] == self.expected_authorization_code
assert params["redirect_uri"] == self.state["redirect_uri"]
return self._build_token_response()
Expand Down Expand Up @@ -241,6 +241,8 @@ def device_code_callback(self, request: requests_mock.request._RequestObjectProx
self.state["device_code"] = random_string()
self.state["user_code"] = random_string(length=6).upper()
self.state["scope"] = params["scope"]
if "code_challenge" in params:
self.state["code_challenge"] = params["code_challenge"]
return json.dumps({
# TODO: also verification_url (google tweak)
"verification_uri": self.provider_root_url + "/dc",
Expand All @@ -252,7 +254,15 @@ def device_code_callback(self, request: requests_mock.request._RequestObjectProx
def token_callback_device_code(self, request: requests_mock.request._RequestObjectProxy, context):
params = self._get_query_params(query=request.text)
assert params["client_id"] == self.expected_client_id
assert params["client_secret"] == self.expected_fields["client_secret"]
expected_client_secret = self.expected_fields.get("client_secret")
if expected_client_secret:
assert params["client_secret"] == expected_client_secret
expect_code_verifier = bool(self.expected_fields.get("code_verifier"))
if expect_code_verifier:
assert PkceCode.sha256_hash(params["code_verifier"]) == self.state["code_challenge"]
self.state["code_verifier"] = params["code_verifier"]
if bool(expected_client_secret) == expect_code_verifier:
pytest.fail("Token callback should either have client secret or PKCE code verifier")
assert params["device_code"] == self.state["device_code"]
assert params["grant_type"] == "urn:ietf:params:oauth:grant-type:device_code"
# Fail with pending/too fast?
Expand Down Expand Up @@ -385,7 +395,7 @@ def test_oidc_resource_owner_password_credentials_flow(requests_mock):
assert oidc_mock.state["access_token"] == tokens.access_token


def test_oidc_device_flow(requests_mock, caplog):
def test_oidc_device_flow_with_client_secret(requests_mock, caplog):
client_id = "myclient"
client_secret = "$3cr3t"
oidc_discovery_url = "http://oidc.test/.well-known/openid-configuration"
Expand Down Expand Up @@ -421,6 +431,42 @@ def test_oidc_device_flow(requests_mock, caplog):
)


def test_oidc_device_flow_with_pkce(requests_mock, caplog):
client_id = "myclient"
oidc_discovery_url = "http://oidc.test/.well-known/openid-configuration"
oidc_mock = OidcMock(
requests_mock=requests_mock,
expected_grant_type="urn:ietf:params:oauth:grant-type:device_code",
expected_client_id=client_id,
oidc_discovery_url=oidc_discovery_url,
expected_fields={"scope": "df openid", "code_verifier": True},
state={"device_code_callback_timeline": ["authorization_pending", "slow_down", "great success"]},
scopes_supported=["openid", "df"]
)
provider = OidcProviderInfo(discovery_url=oidc_discovery_url, scopes=["df"])
display = []
authenticator = OidcDeviceAuthenticator(
client_info=OidcClientInfo(client_id=client_id, provider=provider),
display=display.append,
use_pkce=True
)
with mock.patch.object(openeo.rest.auth.oidc.time, "sleep") as sleep:
with caplog.at_level(logging.INFO):
tokens = authenticator.get_tokens()
assert oidc_mock.state["access_token"] == tokens.access_token
assert re.search(
r"visit https://auth\.test/dc and enter the user code {c!r}".format(c=oidc_mock.state['user_code']),
display[0]
)
assert display[1] == "Authorized successfully."
assert sleep.mock_calls == [mock.call(2), mock.call(2), mock.call(7)]
assert re.search(
r"Authorization pending\..*Polling too fast, will slow down\..*Authorized successfully\.",
caplog.text,
flags=re.DOTALL
)


def test_oidc_refresh_token_flow(requests_mock, caplog):
client_id = "myclient"
client_secret = "$3cr3t"
Expand Down

0 comments on commit 9f5596e

Please sign in to comment.