Skip to content

Commit

Permalink
OAuth2 support for flyte-cli and SDK engine (#23)
Browse files Browse the repository at this point in the history
This change adds authentication support for flyte-cli and the pyflyte CLIs.

# New authorization code
Specifically, this change introduces an **AuthorizationClient** which implements the [PKCE authorization flow](https://www.oauth.com/oauth2-servers/pkce/authorization-code-exchange/) for untrusted clients. This client handles requesting an authorization code, spinning up a callback server to receive that code, and exchanging it for an access token. The client also handles refreshing expired access tokens.

This change also includes a lightweight **DiscoveryClient** for retrieving authorization endpoint metadata defined in the [OAuth 2.0 Authorization Server Metadata](https://tools.ietf.org/id/draft-ietf-oauth-discovery-08.html) draft document.

An authorization client singleton is lazily initialized for use by flyte-cli.

# Pyflyte changes (basic auth)
Requests an authorization token using a username and password.

# Flyte-cli changes (standard auth)
Requests an authorization token using the PKCE flow.

# Raw client changes
Wraps RPC calls to flyteadmin in a retry handler that initiates the appropriate authentication flow defined in the flytekit config in response to `HTTP 401 unauthorized` response codes.
  • Loading branch information
katrogan authored and wild-endeavor committed Dec 6, 2019
1 parent 603a918 commit 439a36d
Show file tree
Hide file tree
Showing 24 changed files with 941 additions and 44 deletions.
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import absolute_import
import flytekit.plugins

__version__ = '0.3.1'
__version__ = '0.4.0'
4 changes: 4 additions & 0 deletions flytekit/clients/helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@

from flytekit.clis.auth import credentials as _credentials_access



def iterate_node_executions(
client,
Expand Down Expand Up @@ -75,3 +78,4 @@ def iterate_task_executions(client, node_execution_identifier, limit=None, filte
if not next_token:
break
token = next_token

142 changes: 114 additions & 28 deletions flytekit/clients/raw.py

Large diffs are not rendered by default.

Empty file added flytekit/clis/auth/__init__.py
Empty file.
280 changes: 280 additions & 0 deletions flytekit/clis/auth/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
import base64 as _base64
import hashlib as _hashlib
import keyring as _keyring
import os as _os
import re as _re
import requests as _requests
import webbrowser as _webbrowser

from multiprocessing import Process as _Process, Queue as _Queue

try: # Python 3.5+
from http import HTTPStatus as _StatusCodes
except ImportError:
try: # Python 3
from http import client as _StatusCodes
except ImportError: # Python 2
import httplib as _StatusCodes
try: # Python 3
import http.server as _BaseHTTPServer
except ImportError: # Python 2
import BaseHTTPServer as _BaseHTTPServer

try: # Python 3
import urllib.parse as _urlparse
from urllib.parse import urlencode as _urlencode
except ImportError: # Python 2
import urlparse as _urlparse
from urllib import urlencode as _urlencode

# Number of random bytes drawn from os.urandom when generating the PKCE code verifier.
_code_verifier_length = 64
# Number of random bytes drawn from os.urandom when generating the OAuth 'state' parameter.
_random_seed_length = 40
# Encoding used when converting between bytes and text in this module.
_utf_8 = 'utf-8'


# Identifies the service under which tokens are stored in keyring.
_keyring_service_name = "flyteauth"
# Identify the keys used for storing and fetching from keyring. Instead of a username, as the keyring docs suggest,
# these fixed keys name which token (access or refresh) is stored under the service above.
_keyring_access_token_storage_key = "access_token"
_keyring_refresh_token_storage_key = "refresh_token"


def _generate_code_verifier():
    """
    Builds a random 'code_verifier' as described in https://tools.ietf.org/html/rfc7636#section-4.1
    Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py.
    :raises ValueError: if the sanitized verifier is shorter than 43 or longer than 128 characters
    :return str:
    """
    random_bytes = _os.urandom(_code_verifier_length)
    encoded = _base64.urlsafe_b64encode(random_bytes).decode(_utf_8)
    # Drop every character outside the permitted set (for urlsafe base64 output that is the '=' padding).
    verifier = _re.sub(r'[^a-zA-Z0-9_\-.~]+', '', encoded)
    verifier_length = len(verifier)
    if verifier_length < 43:
        raise ValueError("Verifier too short. number of bytes must be > 30.")
    if verifier_length > 128:
        raise ValueError("Verifier too long. number of bytes must be < 97.")
    return verifier


def _generate_state_parameter():
    """
    Builds a random value for the OAuth 'state' request parameter.
    :return str: urlsafe base64 text of random bytes with disallowed characters removed
    """
    raw_state = _base64.urlsafe_b64encode(_os.urandom(_random_seed_length))
    # Keep only letters, digits and the '-', '_', '.', ',' characters.
    return _re.sub('[^a-zA-Z0-9-_.,]+', '', raw_state.decode(_utf_8))


def _create_code_challenge(code_verifier):
    """
    Derives the code challenge that corresponds to a code verifier.
    Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py.
    :param str code_verifier: represents a code verifier generated by generate_code_verifier()
    :return str: urlsafe base64-encoded sha256 hash digest, with '=' padding removed
    """
    digest = _hashlib.sha256(code_verifier.encode(_utf_8)).digest()
    encoded_digest = _base64.urlsafe_b64encode(digest).decode(_utf_8)
    # Strip the '=' padding characters from the encoded digest.
    return encoded_digest.replace('=', '')


class AuthorizationCode(object):
    """
    Read-only holder pairing an authorization 'code' with the 'state' value that accompanied it.
    """

    def __init__(self, code, state):
        """
        :param code: the authorization code value
        :param state: the state value received alongside the code
        """
        self._code, self._state = code, state

    @property
    def code(self):
        """
        :return: the code supplied at construction
        """
        return self._code

    @property
    def state(self):
        """
        :return: the state supplied at construction
        """
        return self._state


class OAuthCallbackHandler(_BaseHTTPServer.BaseHTTPRequestHandler):
    """
    A simple wrapper around BaseHTTPServer.BaseHTTPRequestHandler that handles a callback URL that accepts an
    authorization code.
    """

    def do_GET(self):
        """
        Answers 200 and forwards the query parameters to handle_login() when the request path equals the server's
        redirect path; answers 404 for any other path.
        """
        url = _urlparse.urlparse(self.path)
        if url.path != self.server.redirect_path:
            self.send_response(_StatusCodes.NOT_FOUND)
            # end_headers() terminates the header block; without it the 404 response is never completed (and, with
            # Python 3's buffered headers, never written to the client at all).
            self.end_headers()
            return

        self.send_response(_StatusCodes.OK)
        self.end_headers()
        self.handle_login(dict(_urlparse.parse_qsl(url.query)))

    def handle_login(self, data):
        """
        Hands the received authorization code and state over to the server.
        :param dict data: parsed query parameters of the callback request; must contain 'code' and 'state'
        :raises KeyError: if 'code' or 'state' is missing from data
        """
        self.server.handle_authorization_code(AuthorizationCode(data['code'], data['state']))


class OAuthHTTPServer(_BaseHTTPServer.HTTPServer):
    """
    A thin extension of the BaseHTTPServer.HTTPServer implementation that remembers the redirect path it serves and
    publishes received authorization codes onto a queue.
    """

    def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True,
                 redirect_path=None, queue=None):
        """
        :param server_address: passed through to HTTPServer
        :param RequestHandlerClass: passed through to HTTPServer
        :param bind_and_activate: passed through to HTTPServer
        :param redirect_path: path exposed through the redirect_path property
        :param queue: object whose put() receives each authorization code
        """
        _BaseHTTPServer.HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate)
        self._queue = queue
        self._auth_code = None
        self._redirect_path = redirect_path

    @property
    def redirect_path(self):
        """
        :return: the redirect path supplied at construction
        """
        return self._redirect_path

    def handle_authorization_code(self, auth_code):
        """
        Publishes the given authorization code on the queue supplied at construction.
        :param auth_code: value to put on the queue
        """
        self._queue.put(auth_code)


class Credentials(object):
    """
    Read-only holder for an access token.
    """

    def __init__(self, access_token=None):
        """
        :param access_token: the token value to hold; defaults to None
        """
        self._access_token = access_token

    @property
    def access_token(self):
        """
        :return: the access token supplied at construction
        """
        return self._access_token


class AuthorizationClient(object):
    """
    Client for the OAuth2 authorization code flow with a PKCE code challenge.

    Tokens are cached in keyring. When no cached access token exists, constructing the client opens the authorization
    endpoint in a browser tab, waits for the authorization code to arrive on a local callback server and exchanges
    it for an access token.
    """

    def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redirect_uri=None):
        """
        NOTE: when keyring holds no access token this constructor blocks until the callback server has received an
        authorization code.

        :param auth_endpoint: URL of the authorization endpoint, opened in the browser
        :param token_endpoint: URL that token requests are POSTed to
        :param client_id: sent as 'client_id' in authorization and token requests
        :param redirect_uri: callback location; its host, port and path are used for the local callback server
        """
        self._auth_endpoint = auth_endpoint
        self._token_endpoint = token_endpoint
        self._client_id = client_id
        self._redirect_uri = redirect_uri
        self._code_verifier = _generate_code_verifier()
        code_challenge = _create_code_challenge(self._code_verifier)
        self._code_challenge = code_challenge
        state = _generate_state_parameter()
        self._state = state
        self._credentials = None
        self._refresh_token = None
        self._headers = {'content-type': "application/x-www-form-urlencoded"}
        self._expired = False

        self._params = {
            "client_id": client_id,  # This must match the Client ID of the OAuth application.
            "response_type": "code",  # Indicates the authorization code grant
            "scope": "openid offline_access",  # ensures that the /token endpoint returns an ID and refresh token
            # callback location where the user-agent will be directed to.
            "redirect_uri": self._redirect_uri,
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
        }

        # Prefer to use already-fetched token values when they've been set globally.
        self._refresh_token = _keyring.get_password(_keyring_service_name, _keyring_refresh_token_storage_key)
        access_token = _keyring.get_password(_keyring_service_name, _keyring_access_token_storage_key)
        if access_token:
            self._credentials = Credentials(access_token=access_token)
            return

        # In the absence of globally-set token values, initiate the token request flow
        q = _Queue()
        # First prepare the callback server in the background
        server = self._create_callback_server(q)
        server_process = _Process(target=server.handle_request)
        server_process.start()
        try:
            # Send the call to request the authorization code
            self._request_authorization_code()
            # Block until the callback server hands over the auth code.
            auth_code = q.get()
        finally:
            # Always stop and reap the background process and release this process's listening socket, even when
            # waiting for the code is interrupted.
            server_process.terminate()
            server_process.join()
            server.server_close()

        # Request the access token once the auth code has been received.
        self.request_access_token(auth_code)

    def _create_callback_server(self, q):
        """
        Creates the callback server for the host, port and path of the redirect URI.
        :param q: queue on which the server publishes the received authorization code
        :return OAuthHTTPServer:
        """
        server_url = _urlparse.urlparse(self._redirect_uri)
        server_address = (server_url.hostname, server_url.port)
        return OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path, queue=q)

    def _request_authorization_code(self):
        """
        Opens the authorization endpoint, with the request parameters as its query string, in a new browser tab.
        """
        scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint)
        query = _urlencode(self._params)
        endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None))
        _webbrowser.open_new_tab(endpoint)

    def _initialize_credentials(self, auth_token_resp):
        """
        Stores the tokens of a token response in keyring and on this client.

        The auth_token_resp body is of the form:
        {
          "access_token": "foo",
          "refresh_token": "bar",
          "id_token": "baz",
          "token_type": "Bearer"
        }
        "refresh_token" may be absent; in that case any previously known refresh token is left untouched.

        :param auth_token_resp: response object whose json() returns the token response body
        :raises ValueError: if the body carries no "access_token"
        """
        response_body = auth_token_resp.json()
        if "access_token" not in response_body:
            raise ValueError('Expected "access_token" in response from oauth server')

        access_token = response_body["access_token"]
        _keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token)

        # Only touch the refresh token when the server actually returned one; an unconditional lookup would raise
        # KeyError for responses that omit it.
        if "refresh_token" in response_body:
            self._refresh_token = response_body["refresh_token"]
            _keyring.set_password(_keyring_service_name, _keyring_refresh_token_storage_key, self._refresh_token)

        self._credentials = Credentials(access_token=access_token)

    def request_access_token(self, auth_code):
        """
        Exchanges an authorization code for an access token and stores the resulting credentials.
        :param AuthorizationCode auth_code: code and state received on the callback server
        :raises ValueError: if the state of auth_code differs from the state this client generated
        :raises Exception: if the token endpoint answers with a status other than 200
        """
        if self._state != auth_code.state:
            raise ValueError("Unexpected state parameter [{}] passed".format(auth_code.state))
        self._params.update({
            "code": auth_code.code,
            "code_verifier": self._code_verifier,
            "grant_type": "authorization_code",
        })
        resp = _requests.post(
            url=self._token_endpoint,
            data=self._params,
            headers=self._headers,
            allow_redirects=False
        )
        if resp.status_code != _StatusCodes.OK:
            # TODO: handle expected (?) error cases:
            #  https://auth0.com/docs/flows/guides/device-auth/call-api-device-auth#token-responses
            raise Exception('Failed to request access token with response: [{}] {}'.format(
                resp.status_code, resp.content))
        self._initialize_credentials(resp)

    def refresh_access_token(self):
        """
        Requests a new access token using the stored refresh token.

        On a non-200 response the client is marked expired and the tokens stored in keyring are deleted.
        :raises ValueError: if no refresh token is available
        """
        if self._refresh_token is None:
            raise ValueError("no refresh token available with which to refresh authorization credentials")

        resp = _requests.post(
            url=self._token_endpoint,
            data={'grant_type': 'refresh_token',
                  'client_id': self._client_id,
                  'refresh_token': self._refresh_token},
            headers=self._headers,
            allow_redirects=False
        )
        if resp.status_code != _StatusCodes.OK:
            self._expired = True
            # In the absence of a successful response, assume the refresh token is expired. This should indicate
            # to the caller that the AuthorizationClient is defunct and a new one needs to be re-initialized.

            _keyring.delete_password(_keyring_service_name, _keyring_access_token_storage_key)
            _keyring.delete_password(_keyring_service_name, _keyring_refresh_token_storage_key)
            return
        self._initialize_credentials(resp)

    @property
    def credentials(self):
        """
        :return flytekit.clis.auth.auth.Credentials:
        """
        return self._credentials

    @property
    def expired(self):
        """
        :return bool:
        """
        return self._expired
46 changes: 46 additions & 0 deletions flytekit/clis/auth/credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import absolute_import
from flytekit.clis.auth.auth import AuthorizationClient as _AuthorizationClient
from flytekit.clis.auth.discovery import DiscoveryClient as _DiscoveryClient

from flytekit.configuration.creds import (
REDIRECT_URI as _REDIRECT_URI,
CLIENT_ID as _CLIENT_ID
)
from flytekit.configuration.platform import URL as _URL, INSECURE as _INSECURE

try: # Python 3
import urllib.parse as _urlparse
except ImportError: # Python 2
import urlparse as _urlparse

# Default well-known URI path used for fetching the authorization server's JSON metadata.
# See https://tools.ietf.org/html/rfc8414#section-3.
discovery_endpoint_path = ".well-known/oauth-authorization-server"


def _get_discovery_endpoint():
    """
    Builds the discovery endpoint URL from the configured platform URL, using http when the platform is configured
    as insecure and https otherwise.
    :return str:
    """
    base_template = 'http://{}/' if _INSECURE.get() else 'https://{}/'
    return _urlparse.urljoin(base_template.format(_URL.get()), discovery_endpoint_path)


# Lazily initialized authorization client singleton; assigned by get_client() on first use.
_authorization_client = None


def get_client():
    """
    Returns the module-level authorization client, creating a new one when none exists yet or when the existing one
    reports itself as expired.
    :return flytekit.clis.auth.auth.AuthorizationClient:
    """
    global _authorization_client
    needs_new_client = _authorization_client is None or _authorization_client.expired
    if needs_new_client:
        endpoints = get_authorization_endpoints()
        _authorization_client = _AuthorizationClient(
            redirect_uri=_REDIRECT_URI.get(),
            client_id=_CLIENT_ID.get(),
            auth_endpoint=endpoints.auth_endpoint,
            token_endpoint=endpoints.token_endpoint
        )
    return _authorization_client


def get_authorization_endpoints():
    """
    Queries the discovery endpoint for the authorization endpoints.
    :return: the result of DiscoveryClient.get_authorization_endpoints()
    """
    client = _DiscoveryClient(discovery_url=_get_discovery_endpoint())
    return client.get_authorization_endpoints()
Loading

0 comments on commit 439a36d

Please sign in to comment.