Skip to content

Commit

Permalink
Feat: Enable flytekit to authenticate with proxy in front of FlyteA…
Browse files Browse the repository at this point in the history
…dmin (#1787)

* Introduce authenticator engine and make proxy auth work

Signed-off-by: Fabio Grätz <[email protected]>

* Use proxy authed session for client credentials flow

Signed-off-by: Fabio Grätz <[email protected]>

* Don't use authenticator engine but do proxy authentication via existing external command authenticator

Signed-off-by: Fabio Grätz <[email protected]>

* Add docstring to AuthenticationHTTPAdapter

Signed-off-by: Fabio Grätz <[email protected]>

* Address todo in docstring

Signed-off-by: Fabio Grätz <[email protected]>

* Create blank session if none provided

Signed-off-by: Fabio Grätz <[email protected]>

* Create blank session if none provided in get_token

Signed-off-by: Fabio Grätz <[email protected]>

* Refresh proxy creds in session when not existing without triggering 401

Signed-off-by: Fabio Grätz <[email protected]>

* Add test for get_session

Signed-off-by: Fabio Grätz <[email protected]>

* Move auth helper test into existing module

Signed-off-by: Fabio Grätz <[email protected]>

* Move auth helper test into existing module

Signed-off-by: Fabio Grätz <[email protected]>

* Add test for upgrade_channel_to_proxy_authenticated

Signed-off-by: Fabio Grätz <[email protected]>

* Auth helper tests without use of responses package

Signed-off-by: Fabio Grätz <[email protected]>

* Feat: Add plugin for generating GCP IAP ID tokens via external command (#1795)

* Add external command plugin to generate id tokens for identity aware proxy

Signed-off-by: Fabio Grätz <[email protected]>

* Retrieve desktop app client secret from gcp secret manager

Signed-off-by: Fabio Grätz <[email protected]>

* Remove comments

Signed-off-by: Fabio Grätz <[email protected]>

* Introduce a command group that allows adding a command to generate service account id tokens later

Signed-off-by: Fabio Grätz <[email protected]>

* Document how to use plugin and deploy Flyte with IAP

Signed-off-by: Fabio Grätz <[email protected]>

* Minor corrections README.md

Signed-off-by: Fabio Grätz <[email protected]>

---------

Signed-off-by: Fabio Grätz <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
Signed-off-by: Fabio Grätz <[email protected]>

* Use proxy auth'ed session for device code auth flow

Signed-off-by: Fabio Grätz <[email protected]>

* Fix token client tests

Signed-off-by: Fabio Grätz <[email protected]>

* Make poll token endpoint test more specific

Signed-off-by: Fabio Grätz <[email protected]>

* Make test_client_creds_authenticator test work and more specific

Signed-off-by: Fabio Grätz <[email protected]>

* Make test_client_creds_authenticator_with_custom_scopes test work and more specific

Signed-off-by: Fabio Grätz <[email protected]>

* Implement subcommand to generate id tokens for service accounts

Signed-off-by: Fabio Graetz <[email protected]>

* Test id token generation from service accounts

Signed-off-by: Fabio Graetz <[email protected]>

* Fix plugin requirements

Signed-off-by: Fabio Graetz <[email protected]>

* Document usage of generate-service-account-id-token subcommand

Signed-off-by: Fabio Grätz <[email protected]>

* Document alternative ways to obtain service account id tokens

Signed-off-by: Fabio Grätz <[email protected]>

---------

Signed-off-by: Fabio Grätz <[email protected]>
Signed-off-by: Fabio Graetz <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
Signed-off-by: Jeev B <[email protected]>
  • Loading branch information
2 people authored and jeevb committed Sep 20, 2023
1 parent 99abcb4 commit 3220a3e
Show file tree
Hide file tree
Showing 18 changed files with 1,155 additions and 65 deletions.
79 changes: 55 additions & 24 deletions flytekit/clients/auth/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ def __init__(
redirect_uri: typing.Optional[str] = None,
endpoint_metadata: typing.Optional[EndpointMetadata] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[_requests.Session] = None,
request_auth_code_params: typing.Optional[typing.Dict[str, str]] = None,
request_access_token_params: typing.Optional[typing.Dict[str, str]] = None,
refresh_access_token_params: typing.Optional[typing.Dict[str, str]] = None,
add_request_auth_code_params_to_request_access_token_params: typing.Optional[bool] = False,
):
"""
Create new AuthorizationClient
Expand All @@ -192,7 +197,9 @@ def __init__(
:param auth_endpoint: str endpoint where auth metadata can be found
:param token_endpoint: str endpoint to retrieve token from
:param scopes: list[str] oauth2 scopes
:param client_id
:param client_id: oauth2 client id
:param redirect_uri: oauth2 redirect uri
:param endpoint_metadata: EndpointMetadata object to control the rendering of the page on login successful or failure
:param verify: (optional) Either a boolean, in which case it controls whether we verify
the server's TLS certificate, or a string, in which case it must be a path
to a CA bundle to use. Defaults to ``True``. When set to
Expand All @@ -201,6 +208,15 @@ def __init__(
certificates, which will make your application vulnerable to
man-in-the-middle (MitM) attacks. Setting verify to ``False``
may be useful during local development or testing.
:param session: (optional) A custom requests.Session object to use for making HTTP requests.
If not provided, a new Session object will be created.
:param request_auth_code_params: (optional) dict of parameters to add to login uri opened in the browser
:param request_access_token_params: (optional) dict of parameters to add when exchanging the auth code for the access token
:param refresh_access_token_params: (optional) dict of parameters to add when refreshing the access token
:param add_request_auth_code_params_to_request_access_token_params: Whether to add the `request_auth_code_params` to
the parameters sent when exchanging the auth code for the access token. Defaults to False.
Required e.g. for the PKCE flow with flyteadmin.
Not required for e.g. the standard OAuth2 flow on GCP.
"""
self._endpoint = endpoint
self._auth_endpoint = auth_endpoint
Expand All @@ -213,15 +229,13 @@ def __init__(
self._client_id = client_id
self._scopes = scopes or []
self._redirect_uri = redirect_uri
self._code_verifier = _generate_code_verifier()
code_challenge = _create_code_challenge(self._code_verifier)
self._code_challenge = code_challenge
state = _generate_state_parameter()
self._state = state
self._verify = verify
self._headers = {"content-type": "application/x-www-form-urlencoded"}
self._session = session or _requests.Session()

self._params = {
self._request_auth_code_params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
"response_type": "code", # Indicates the authorization code grant
"scope": " ".join(s.strip("' ") for s in self._scopes).strip(
Expand All @@ -230,10 +244,18 @@ def __init__(
# callback location where the user-agent will be directed to.
"redirect_uri": self._redirect_uri,
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}

if request_auth_code_params:
# Allow adding additional parameters to the request_auth_code_params
self._request_auth_code_params.update(request_auth_code_params)

self._request_access_token_params = request_access_token_params or {}
self._refresh_access_token_params = refresh_access_token_params or {}

if add_request_auth_code_params_to_request_access_token_params:
self._request_access_token_params.update(self._request_auth_code_params)

def __repr__(self):
return f"AuthorizationClient({self._auth_endpoint}, {self._token_endpoint}, {self._client_id}, {self._scopes}, {self._redirect_uri})"

Expand All @@ -249,7 +271,7 @@ def _create_callback_server(self):

def _request_authorization_code(self):
scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint)
query = _urlencode(self._params)
query = _urlencode(self._request_auth_code_params)
endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None))
logging.debug(f"Requesting authorization code through {endpoint}")
_webbrowser.open_new_tab(endpoint)
Expand All @@ -262,33 +284,38 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials:
"refresh_token": "bar",
"token_type": "Bearer"
}
Can additionally contain "expires_in" and "id_token" fields.
"""
response_body = auth_token_resp.json()
refresh_token = None
id_token = None
if "access_token" not in response_body:
raise ValueError('Expected "access_token" in response from oauth server')
if "refresh_token" in response_body:
refresh_token = response_body["refresh_token"]
if "expires_in" in response_body:
expires_in = response_body["expires_in"]
access_token = response_body["access_token"]
if "id_token" in response_body:
id_token = response_body["id_token"]

return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in)
return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in, id_token=id_token)

def _request_access_token(self, auth_code) -> Credentials:
if self._state != auth_code.state:
raise ValueError(f"Unexpected state parameter [{auth_code.state}] passed")
self._params.update(
{
"code": auth_code.code,
"code_verifier": self._code_verifier,
"grant_type": "authorization_code",
}
)

resp = _requests.post(
params = {
"code": auth_code.code,
"grant_type": "authorization_code",
}

params.update(self._request_access_token_params)

resp = self._session.post(
url=self._token_endpoint,
data=self._params,
data=params,
headers=self._headers,
allow_redirects=False,
verify=self._verify,
Expand Down Expand Up @@ -332,13 +359,17 @@ def refresh_access_token(self, credentials: Credentials) -> Credentials:
if credentials.refresh_token is None:
raise ValueError("no refresh token available with which to refresh authorization credentials")

resp = _requests.post(
data = {
"refresh_token": credentials.refresh_token,
"grant_type": "refresh_token",
"client_id": self._client_id,
}

data.update(self._refresh_access_token_params)

resp = self._session.post(
url=self._token_endpoint,
data={
"grant_type": "refresh_token",
"client_id": self._client_id,
"refresh_token": credentials.refresh_token,
},
data=data,
headers=self._headers,
allow_redirects=False,
verify=self._verify,
Expand Down
32 changes: 31 additions & 1 deletion flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass

import click
import requests

from . import token_client
from .auth_client import AuthorizationClient
Expand Down Expand Up @@ -95,16 +96,24 @@ def __init__(
cfg_store: ClientConfigStore,
header_key: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[requests.Session] = None,
):
"""
Initialize with default creds from KeyStore using the endpoint name
"""
super().__init__(endpoint, header_key, KeyringStore.retrieve(endpoint), verify=verify)
self._cfg_store = cfg_store
self._auth_client = None
self._session = session or requests.Session()

def _initialize_auth_client(self):
if not self._auth_client:

from .auth_client import _create_code_challenge, _generate_code_verifier

code_verifier = _generate_code_verifier()
code_challenge = _create_code_challenge(code_verifier)

cfg = self._cfg_store.get_client_config()
self._set_header_key(cfg.header_key)
self._auth_client = AuthorizationClient(
Expand All @@ -115,6 +124,16 @@ def _initialize_auth_client(self):
auth_endpoint=cfg.authorization_endpoint,
token_endpoint=cfg.token_endpoint,
verify=self._verify,
session=self._session,
request_auth_code_params={
"code_challenge": code_challenge,
"code_challenge_method": "S256",
},
request_access_token_params={
"code_verifier": code_verifier,
},
refresh_access_token_params={},
add_request_auth_code_params_to_request_access_token_params=True,
)

def refresh_credentials(self):
Expand Down Expand Up @@ -176,6 +195,7 @@ def __init__(
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
audience: typing.Optional[str] = None,
session: typing.Optional[requests.Session] = None,
):
if not client_id or not client_secret:
raise ValueError("Client ID and Client SECRET both are required.")
Expand All @@ -186,6 +206,7 @@ def __init__(
self._client_id = client_id
self._client_secret = client_secret
self._audience = audience or cfg.audience
self._session = session or requests.Session()
super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url, verify=verify)

def refresh_credentials(self):
Expand All @@ -211,6 +232,7 @@ def refresh_credentials(self):
verify=self._verify,
scopes=scopes,
audience=audience,
session=self._session,
)

logging.info("Retrieved new token, expires in {}".format(expires_in))
Expand All @@ -234,6 +256,7 @@ def __init__(
audience: typing.Optional[str] = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[requests.Session] = None,
):
self._audience = audience
cfg = cfg_store.get_client_config()
Expand All @@ -245,6 +268,7 @@ def __init__(
raise AuthenticationError(
"Device Authentication is not available on the Flyte backend / authentication server"
)
self._session = session or requests.Session()
super().__init__(
endpoint=endpoint,
header_key=header_key or cfg.header_key,
Expand All @@ -255,7 +279,13 @@ def __init__(

def refresh_credentials(self):
resp = token_client.get_device_code(
self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url, self._verify
self._device_auth_endpoint,
self._client_id,
self._audience,
self._scope,
self._http_proxy_url,
self._verify,
self._session,
)
text = f"To Authenticate, navigate in a browser to the following URL: {click.style(resp.verification_uri, fg='blue', underline=True)} and enter code: {click.style(resp.user_code, fg='blue')}"
click.secho(text)
Expand Down
30 changes: 22 additions & 8 deletions flytekit/clients/auth/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass

import keyring as _keyring
from keyring.errors import NoKeyringError
from keyring.errors import NoKeyringError, PasswordDeleteError


@dataclass
Expand All @@ -16,6 +16,7 @@ class Credentials(object):
refresh_token: str = "na"
for_endpoint: str = "flyte-default"
expires_in: typing.Optional[int] = None
id_token: typing.Optional[str] = None


class KeyringStore:
Expand All @@ -25,20 +26,28 @@ class KeyringStore:

_access_token_key = "access_token"
_refresh_token_key = "refresh_token"
_id_token_key = "id_token"

@staticmethod
def store(credentials: Credentials) -> Credentials:
try:
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._refresh_token_key,
credentials.refresh_token,
)
if credentials.refresh_token:
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._refresh_token_key,
credentials.refresh_token,
)
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._access_token_key,
credentials.access_token,
)
if credentials.id_token:
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._id_token_key,
credentials.id_token,
)
except NoKeyringError as e:
logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
return credentials
Expand All @@ -48,18 +57,23 @@ def retrieve(for_endpoint: str) -> typing.Optional[Credentials]:
try:
refresh_token = _keyring.get_password(for_endpoint, KeyringStore._refresh_token_key)
access_token = _keyring.get_password(for_endpoint, KeyringStore._access_token_key)
id_token = _keyring.get_password(for_endpoint, KeyringStore._id_token_key)
except NoKeyringError as e:
logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
return None

if not access_token:
if not access_token and not id_token:
return None
return Credentials(access_token, refresh_token, for_endpoint)
return Credentials(access_token, refresh_token, for_endpoint, id_token=id_token)

@staticmethod
def delete(for_endpoint: str):
try:
_keyring.delete_password(for_endpoint, KeyringStore._access_token_key)
_keyring.delete_password(for_endpoint, KeyringStore._refresh_token_key)
try:
_keyring.delete_password(for_endpoint, KeyringStore._id_token_key)
except PasswordDeleteError as e:
logging.debug(f"Id token not found in key store, not deleting. Error: {e}")
except NoKeyringError as e:
logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
11 changes: 9 additions & 2 deletions flytekit/clients/auth/token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def get_token(
grant_type: GrantType = GrantType.CLIENT_CREDS,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[requests.Session] = None,
) -> typing.Tuple[str, int]:
"""
:rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
Expand All @@ -103,7 +104,10 @@ def get_token(
body["audience"] = audience

proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify)

if not session:
session = requests.Session()
response = session.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify)

if not response.ok:
j = response.json()
Expand All @@ -125,6 +129,7 @@ def get_device_code(
scope: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[requests.Session] = None,
) -> DeviceCodeResponse:
"""
Retrieves the device Authentication code that can be done to authenticate the request using a browser on a
Expand All @@ -133,7 +138,9 @@ def get_device_code(
_scope = " ".join(scope) if scope is not None else ""
payload = {"client_id": client_id, "scope": _scope, "audience": audience}
proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
resp = requests.post(device_auth_endpoint, payload, proxies=proxies, verify=verify)
if not session:
session = requests.Session()
resp = session.post(device_auth_endpoint, payload, proxies=proxies, verify=verify)
if not resp.ok:
raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}")
return DeviceCodeResponse.from_json_response(resp.json())
Expand Down
Loading

0 comments on commit 3220a3e

Please sign in to comment.