-
Notifications
You must be signed in to change notification settings - Fork 301
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
OAuth2 support for flyte-cli and SDK engine (#23)
This change adds authentication support for flyte-cli and the pyflyte CLIs. # New authorization code Specifically this change introduces an **AuthorizationClient** which implements the [PKCE authorization flow](https://www.oauth.com/oauth2-servers/pkce/authorization-code-exchange/) for untrusted clients. This client handles requesting an initial access token, spinning up a callback server to receive the access token and using that to retrieve an authorization code. The client also handles refreshing expired authorization tokens. This change also includes a lightweight **DiscoveryClient** for retrieving authorization endpoint metadata defined in the [OAuth 2.0 Authorization Server Metadata](https://tools.ietf.org/id/draft-ietf-oauth-discovery-08.html) draft document. An authorization client singleton is lazily initialized for use by flyte-cli. # Pyflyte changes (basic auth) Requests an authorization token using a username and password. # Flyte-cli changes (standard auth) Requests an authorization token using the PKCE flow. # Raw client changes Wraps RPC calls to flyteadmin in a retry handler that initiates the appropriate authentication flow defined in the flytekit config in response to `HTTP 401 unauthorized` response codes.
- Loading branch information
1 parent
603a918
commit 439a36d
Showing
24 changed files
with
941 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from __future__ import absolute_import | ||
import flytekit.plugins | ||
|
||
__version__ = '0.3.1' | ||
__version__ = '0.4.0' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,280 @@ | ||
import base64 as _base64 | ||
import hashlib as _hashlib | ||
import keyring as _keyring | ||
import os as _os | ||
import re as _re | ||
import requests as _requests | ||
import webbrowser as _webbrowser | ||
|
||
from multiprocessing import Process as _Process, Queue as _Queue | ||
|
||
try: # Python 3.5+ | ||
from http import HTTPStatus as _StatusCodes | ||
except ImportError: | ||
try: # Python 3 | ||
from http import client as _StatusCodes | ||
except ImportError: # Python 2 | ||
import httplib as _StatusCodes | ||
try: # Python 3 | ||
import http.server as _BaseHTTPServer | ||
except ImportError: # Python 2 | ||
import BaseHTTPServer as _BaseHTTPServer | ||
|
||
try: # Python 3 | ||
import urllib.parse as _urlparse | ||
from urllib.parse import urlencode as _urlencode | ||
except ImportError: # Python 2 | ||
import urlparse as _urlparse | ||
from urllib import urlencode as _urlencode | ||
|
||
_code_verifier_length = 64 | ||
_random_seed_length = 40 | ||
_utf_8 = 'utf-8' | ||
|
||
|
||
# Identifies the service used for storing passwords in keyring | ||
_keyring_service_name = "flyteauth" | ||
# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs | ||
# suggest, we are storing a user's oidc. | ||
_keyring_access_token_storage_key = "access_token" | ||
_keyring_refresh_token_storage_key = "refresh_token" | ||
|
||
|
||
def _generate_code_verifier(): | ||
""" | ||
Generates a 'code_verifier' as described in https://tools.ietf.org/html/rfc7636#section-4.1 | ||
Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py. | ||
:return str: | ||
""" | ||
code_verifier = _base64.urlsafe_b64encode(_os.urandom(_code_verifier_length)).decode(_utf_8) | ||
# Eliminate invalid characters. | ||
code_verifier = _re.sub(r'[^a-zA-Z0-9_\-.~]+', '', code_verifier) | ||
if len(code_verifier) < 43: | ||
raise ValueError("Verifier too short. number of bytes must be > 30.") | ||
elif len(code_verifier) > 128: | ||
raise ValueError("Verifier too long. number of bytes must be < 97.") | ||
return code_verifier | ||
|
||
|
||
def _generate_state_parameter(): | ||
state = _base64.urlsafe_b64encode(_os.urandom(_random_seed_length)).decode(_utf_8) | ||
# Eliminate invalid characters. | ||
code_verifier = _re.sub('[^a-zA-Z0-9-_.,]+', '', state) | ||
return code_verifier | ||
|
||
|
||
def _create_code_challenge(code_verifier): | ||
""" | ||
Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py. | ||
:param str code_verifier: represents a code verifier generated by generate_code_verifier() | ||
:return str: urlsafe base64-encoded sha256 hash digest | ||
""" | ||
code_challenge = _hashlib.sha256(code_verifier.encode(_utf_8)).digest() | ||
code_challenge = _base64.urlsafe_b64encode(code_challenge).decode(_utf_8) | ||
# Eliminate invalid characters | ||
code_challenge = code_challenge.replace('=', '') | ||
return code_challenge | ||
|
||
|
||
class AuthorizationCode(object): | ||
def __init__(self, code, state): | ||
self._code = code | ||
self._state = state | ||
|
||
@property | ||
def code(self): | ||
return self._code | ||
|
||
@property | ||
def state(self): | ||
return self._state | ||
|
||
|
||
class OAuthCallbackHandler(_BaseHTTPServer.BaseHTTPRequestHandler): | ||
""" | ||
A simple wrapper around BaseHTTPServer.BaseHTTPRequestHandler that handles a callback URL that accepts an | ||
authorization token. | ||
""" | ||
|
||
def do_GET(self): | ||
url = _urlparse.urlparse(self.path) | ||
if url.path == self.server.redirect_path: | ||
self.send_response(_StatusCodes.OK) | ||
self.end_headers() | ||
self.handle_login(dict(_urlparse.parse_qsl(url.query))) | ||
else: | ||
self.send_response(_StatusCodes.NOT_FOUND) | ||
|
||
def handle_login(self, data): | ||
self.server.handle_authorization_code(AuthorizationCode(data['code'], data['state'])) | ||
|
||
|
||
class OAuthHTTPServer(_BaseHTTPServer.HTTPServer): | ||
""" | ||
A simple wrapper around the BaseHTTPServer.HTTPServer implementation that binds an authorization_client for handling | ||
authorization code callbacks. | ||
""" | ||
def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True, | ||
redirect_path=None, queue=None): | ||
_BaseHTTPServer.HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate) | ||
self._redirect_path = redirect_path | ||
self._auth_code = None | ||
self._queue = queue | ||
|
||
@property | ||
def redirect_path(self): | ||
return self._redirect_path | ||
|
||
def handle_authorization_code(self, auth_code): | ||
self._queue.put(auth_code) | ||
|
||
|
||
class Credentials(object): | ||
def __init__(self, access_token=None): | ||
self._access_token = access_token | ||
|
||
@property | ||
def access_token(self): | ||
return self._access_token | ||
|
||
|
||
class AuthorizationClient(object): | ||
def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redirect_uri=None): | ||
self._auth_endpoint = auth_endpoint | ||
self._token_endpoint = token_endpoint | ||
self._client_id = client_id | ||
self._redirect_uri = redirect_uri | ||
self._code_verifier = _generate_code_verifier() | ||
code_challenge = _create_code_challenge(self._code_verifier) | ||
self._code_challenge = code_challenge | ||
state = _generate_state_parameter() | ||
self._state = state | ||
self._credentials = None | ||
self._refresh_token = None | ||
self._headers = {'content-type': "application/x-www-form-urlencoded"} | ||
self._expired = False | ||
|
||
self._params = { | ||
"client_id": client_id, # This must match the Client ID of the OAuth application. | ||
"response_type": "code", # Indicates the authorization code grant | ||
"scope": "openid offline_access", # ensures that the /token endpoint returns an ID and refresh token | ||
# callback location where the user-agent will be directed to. | ||
"redirect_uri": self._redirect_uri, | ||
"state": state, | ||
"code_challenge": code_challenge, | ||
"code_challenge_method": "S256", | ||
} | ||
|
||
# Prefer to use already-fetched token values when they've been set globally. | ||
self._refresh_token = _keyring.get_password(_keyring_service_name, _keyring_refresh_token_storage_key) | ||
access_token = _keyring.get_password(_keyring_service_name, _keyring_access_token_storage_key) | ||
if access_token: | ||
self._credentials = Credentials(access_token=access_token) | ||
return | ||
|
||
# In the absence of globally-set token values, initiate the token request flow | ||
q = _Queue() | ||
# First prepare the callback server in the background | ||
server = self._create_callback_server(q) | ||
server_process = _Process(target=server.handle_request) | ||
server_process.start() | ||
|
||
# Send the call to request the authorization code | ||
self._request_authorization_code() | ||
|
||
# Request the access token once the auth code has been received. | ||
auth_code = q.get() | ||
server_process.terminate() | ||
self.request_access_token(auth_code) | ||
|
||
def _create_callback_server(self, q): | ||
server_url = _urlparse.urlparse(self._redirect_uri) | ||
server_address = (server_url.hostname, server_url.port) | ||
return OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path, queue=q) | ||
|
||
def _request_authorization_code(self): | ||
scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint) | ||
query = _urlencode(self._params) | ||
endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None)) | ||
_webbrowser.open_new_tab(endpoint) | ||
|
||
def _initialize_credentials(self, auth_token_resp): | ||
|
||
""" | ||
The auth_token_resp body is of the form: | ||
{ | ||
"access_token": "foo", | ||
"refresh_token": "bar", | ||
"id_token": "baz", | ||
"token_type": "Bearer" | ||
} | ||
""" | ||
response_body = auth_token_resp.json() | ||
if "access_token" not in response_body: | ||
raise ValueError('Expected "access_token" in response from oauth server') | ||
if "refresh_token" in response_body: | ||
self._refresh_token = response_body["refresh_token"] | ||
|
||
access_token = response_body["access_token"] | ||
refresh_token = response_body["refresh_token"] | ||
|
||
_keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token) | ||
_keyring.set_password(_keyring_service_name, _keyring_refresh_token_storage_key, refresh_token) | ||
self._credentials = Credentials(access_token=access_token) | ||
|
||
def request_access_token(self, auth_code): | ||
if self._state != auth_code.state: | ||
raise ValueError("Unexpected state parameter [{}] passed".format(auth_code.state)) | ||
self._params.update({ | ||
"code": auth_code.code, | ||
"code_verifier": self._code_verifier, | ||
"grant_type": "authorization_code", | ||
}) | ||
resp = _requests.post( | ||
url=self._token_endpoint, | ||
data=self._params, | ||
headers=self._headers, | ||
allow_redirects=False | ||
) | ||
if resp.status_code != _StatusCodes.OK: | ||
# TODO: handle expected (?) error cases: | ||
# https://auth0.com/docs/flows/guides/device-auth/call-api-device-auth#token-responses | ||
raise Exception('Failed to request access token with response: [{}] {}'.format( | ||
resp.status_code, resp.content)) | ||
self._initialize_credentials(resp) | ||
|
||
def refresh_access_token(self): | ||
if self._refresh_token is None: | ||
raise ValueError("no refresh token available with which to refresh authorization credentials") | ||
|
||
resp = _requests.post( | ||
url=self._token_endpoint, | ||
data={'grant_type': 'refresh_token', | ||
'client_id': self._client_id, | ||
'refresh_token': self._refresh_token}, | ||
headers=self._headers, | ||
allow_redirects=False | ||
) | ||
if resp.status_code != _StatusCodes.OK: | ||
self._expired = True | ||
# In the absence of a successful response, assume the refresh token is expired. This should indicate | ||
# to the caller that the AuthorizationClient is defunct and a new one needs to be re-initialized. | ||
|
||
_keyring.delete_password(_keyring_service_name, _keyring_access_token_storage_key) | ||
_keyring.delete_password(_keyring_service_name, _keyring_refresh_token_storage_key) | ||
return | ||
self._initialize_credentials(resp) | ||
|
||
@property | ||
def credentials(self): | ||
""" | ||
:return flytekit.clis.auth.auth.Credentials: | ||
""" | ||
return self._credentials | ||
|
||
@property | ||
def expired(self): | ||
""" | ||
:return bool: | ||
""" | ||
return self._expired |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from __future__ import absolute_import | ||
from flytekit.clis.auth.auth import AuthorizationClient as _AuthorizationClient | ||
from flytekit.clis.auth.discovery import DiscoveryClient as _DiscoveryClient | ||
|
||
from flytekit.configuration.creds import ( | ||
REDIRECT_URI as _REDIRECT_URI, | ||
CLIENT_ID as _CLIENT_ID | ||
) | ||
from flytekit.configuration.platform import URL as _URL, INSECURE as _INSECURE | ||
|
||
try: # Python 3 | ||
import urllib.parse as _urlparse | ||
except ImportError: # Python 2 | ||
import urlparse as _urlparse | ||
|
||
# Default, well known-URI string used for fetching JSON metadata. See https://tools.ietf.org/html/rfc8414#section-3. | ||
discovery_endpoint_path = ".well-known/oauth-authorization-server" | ||
|
||
|
||
def _get_discovery_endpoint(): | ||
if _INSECURE.get(): | ||
return _urlparse.urljoin('http://{}/'.format(_URL.get()), discovery_endpoint_path) | ||
return _urlparse.urljoin('https://{}/'.format(_URL.get()), discovery_endpoint_path) | ||
|
||
|
||
# Lazy initialized authorization client singleton | ||
_authorization_client = None | ||
|
||
|
||
def get_client(): | ||
global _authorization_client | ||
if _authorization_client is not None and not _authorization_client.expired: | ||
return _authorization_client | ||
authorization_endpoints = get_authorization_endpoints() | ||
|
||
_authorization_client =\ | ||
_AuthorizationClient(redirect_uri=_REDIRECT_URI.get(), client_id=_CLIENT_ID.get(), | ||
auth_endpoint=authorization_endpoints.auth_endpoint, | ||
token_endpoint=authorization_endpoints.token_endpoint) | ||
return _authorization_client | ||
|
||
|
||
def get_authorization_endpoints(): | ||
discovery_endpoint = _get_discovery_endpoint() | ||
discovery_client = _DiscoveryClient(discovery_url=discovery_endpoint) | ||
return discovery_client.get_authorization_endpoints() |
Oops, something went wrong.