From 3220a3ecca888cc5a0118607a86b18350ac74e5d Mon Sep 17 00:00:00 2001 From: "Fabio M. Graetz, Ph.D" Date: Wed, 20 Sep 2023 20:52:53 +0200 Subject: [PATCH] Feat: Enable `flytekit` to authenticate with proxy in front of FlyteAdmin (#1787) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Introduce authenticator engine and make proxy auth work Signed-off-by: Fabio Grätz * Use proxy authed session for client credentials flow Signed-off-by: Fabio Grätz * Don't use authenticator engine but do proxy authentication via existing external command authenticator Signed-off-by: Fabio Grätz * Add docstring to AuthenticationHTTPAdapter Signed-off-by: Fabio Grätz * Address todo in docstring Signed-off-by: Fabio Grätz * Create blank session if none provided Signed-off-by: Fabio Grätz * Create blank session if none provided in get_token Signed-off-by: Fabio Grätz * Refresh proxy creds in session when not existing without triggering 401 Signed-off-by: Fabio Grätz * Add test for get_session Signed-off-by: Fabio Grätz * Move auth helper test into existing module Signed-off-by: Fabio Grätz * Move auth helper test into existing module Signed-off-by: Fabio Grätz * Add test for upgrade_channel_to_proxy_authenticated Signed-off-by: Fabio Grätz * Auth helper tests without use of responses package Signed-off-by: Fabio Grätz * Feat: Add plugin for generating GCP IAP ID tokens via external command (#1795) * Add external command plugin to generate id tokens for identity aware proxy Signed-off-by: Fabio Grätz * Retrieve desktop app client secret from gcp secret manager Signed-off-by: Fabio Grätz * Remove comments Signed-off-by: Fabio Grätz * Introduce a command group that allows adding a command to generate service account id tokens later Signed-off-by: Fabio Grätz * Document how to use plugin and deploy Flyte with IAP Signed-off-by: Fabio Grätz * Minor corrections README.md Signed-off-by: Fabio Grätz --------- Signed-off-by: Fabio Grätz Co-authored-by: 
Fabio Grätz Signed-off-by: Fabio Grätz * Use proxy auth'ed session for device code auth flow Signed-off-by: Fabio Grätz * Fix token client tests Signed-off-by: Fabio Grätz * Make poll token endpoint test more specific Signed-off-by: Fabio Grätz * Make test_client_creds_authenticator test work and more specific Signed-off-by: Fabio Grätz * Make test_client_creds_authenticator_with_custom_scopes test work and more specific Signed-off-by: Fabio Grätz * Implement subcommand to generate id tokens for service accounts Signed-off-by: Fabio Graetz * Test id token generation from service accounts Signed-off-by: Fabio Graetz * Fix plugin requirements Signed-off-by: Fabio Graetz * Document usage of generate-service-account-id-token subcommand Signed-off-by: Fabio Grätz * Document alternative ways to obtain service account id tokens Signed-off-by: Fabio Grätz --------- Signed-off-by: Fabio Grätz Signed-off-by: Fabio Graetz Co-authored-by: Fabio Grätz Signed-off-by: Jeev B --- flytekit/clients/auth/auth_client.py | 79 ++-- flytekit/clients/auth/authenticator.py | 32 +- flytekit/clients/auth/keyring.py | 30 +- flytekit/clients/auth/token_client.py | 11 +- flytekit/clients/auth_helper.py | 92 +++- .../clients/grpc_utils/auth_interceptor.py | 2 +- flytekit/clients/raw.py | 11 +- flytekit/configuration/__init__.py | 3 + flytekit/configuration/internal.py | 7 + .../flytekit-identity-aware-proxy/README.md | 404 ++++++++++++++++++ .../identity_aware_proxy/__init__.py | 0 .../identity_aware_proxy/cli.py | 248 +++++++++++ .../flytekit-identity-aware-proxy/setup.py | 43 ++ .../tests/__init__.py | 0 .../tests/test_flytekitplugins_iap.py | 146 +++++++ .../unit/clients/auth/test_authenticator.py | 35 +- .../unit/clients/auth/test_token_client.py | 34 +- .../flytekit/unit/clients/test_auth_helper.py | 43 +- 18 files changed, 1155 insertions(+), 65 deletions(-) create mode 100644 plugins/flytekit-identity-aware-proxy/README.md create mode 100644 
plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/__init__.py create mode 100644 plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/cli.py create mode 100644 plugins/flytekit-identity-aware-proxy/setup.py create mode 100644 plugins/flytekit-identity-aware-proxy/tests/__init__.py create mode 100644 plugins/flytekit-identity-aware-proxy/tests/test_flytekitplugins_iap.py diff --git a/flytekit/clients/auth/auth_client.py b/flytekit/clients/auth/auth_client.py index ec1fd4d3e1..29a995ca89 100644 --- a/flytekit/clients/auth/auth_client.py +++ b/flytekit/clients/auth/auth_client.py @@ -184,6 +184,11 @@ def __init__( redirect_uri: typing.Optional[str] = None, endpoint_metadata: typing.Optional[EndpointMetadata] = None, verify: typing.Optional[typing.Union[bool, str]] = None, + session: typing.Optional[_requests.Session] = None, + request_auth_code_params: typing.Optional[typing.Dict[str, str]] = None, + request_access_token_params: typing.Optional[typing.Dict[str, str]] = None, + refresh_access_token_params: typing.Optional[typing.Dict[str, str]] = None, + add_request_auth_code_params_to_request_access_token_params: typing.Optional[bool] = False, ): """ Create new AuthorizationClient @@ -192,7 +197,9 @@ def __init__( :param auth_endpoint: str endpoint where auth metadata can be found :param token_endpoint: str endpoint to retrieve token from :param scopes: list[str] oauth2 scopes - :param client_id + :param client_id: oauth2 client id + :param redirect_uri: oauth2 redirect uri + :param endpoint_metadata: EndpointMetadata object to control the rendering of the page on login successful or failure :param verify: (optional) Either a boolean, in which case it controls whether we verify the server's TLS certificate, or a string, in which case it must be a path to a CA bundle to use. Defaults to ``True``. 
When set to @@ -201,6 +208,15 @@ def __init__( certificates, which will make your application vulnerable to man-in-the-middle (MitM) attacks. Setting verify to ``False`` may be useful during local development or testing. + :param session: (optional) A custom requests.Session object to use for making HTTP requests. + If not provided, a new Session object will be created. + :param request_auth_code_params: (optional) dict of parameters to add to login uri opened in the browser + :param request_access_token_params: (optional) dict of parameters to add when exchanging the auth code for the access token + :param refresh_access_token_params: (optional) dict of parameters to add when refreshing the access token + :param add_request_auth_code_params_to_request_access_token_params: Whether to add the `request_auth_code_params` to + the parameters sent when exchanging the auth code for the access token. Defaults to False. + Required e.g. for the PKCE flow with flyteadmin. + Not required for e.g. the standard OAuth2 flow on GCP. """ self._endpoint = endpoint self._auth_endpoint = auth_endpoint @@ -213,15 +229,13 @@ def __init__( self._client_id = client_id self._scopes = scopes or [] self._redirect_uri = redirect_uri - self._code_verifier = _generate_code_verifier() - code_challenge = _create_code_challenge(self._code_verifier) - self._code_challenge = code_challenge state = _generate_state_parameter() self._state = state self._verify = verify self._headers = {"content-type": "application/x-www-form-urlencoded"} + self._session = session or _requests.Session() - self._params = { + self._request_auth_code_params = { "client_id": client_id, # This must match the Client ID of the OAuth application. "response_type": "code", # Indicates the authorization code grant "scope": " ".join(s.strip("' ") for s in self._scopes).strip( @@ -230,10 +244,18 @@ def __init__( # callback location where the user-agent will be directed to. 
"redirect_uri": self._redirect_uri, "state": state, - "code_challenge": code_challenge, - "code_challenge_method": "S256", } + if request_auth_code_params: + # Allow adding additional parameters to the request_auth_code_params + self._request_auth_code_params.update(request_auth_code_params) + + self._request_access_token_params = request_access_token_params or {} + self._refresh_access_token_params = refresh_access_token_params or {} + + if add_request_auth_code_params_to_request_access_token_params: + self._request_access_token_params.update(self._request_auth_code_params) + def __repr__(self): return f"AuthorizationClient({self._auth_endpoint}, {self._token_endpoint}, {self._client_id}, {self._scopes}, {self._redirect_uri})" @@ -249,7 +271,7 @@ def _create_callback_server(self): def _request_authorization_code(self): scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint) - query = _urlencode(self._params) + query = _urlencode(self._request_auth_code_params) endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None)) logging.debug(f"Requesting authorization code through {endpoint}") _webbrowser.open_new_tab(endpoint) @@ -262,9 +284,12 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials: "refresh_token": "bar", "token_type": "Bearer" } + + Can additionally contain "expires_in" and "id_token" fields. 
""" response_body = auth_token_resp.json() refresh_token = None + id_token = None if "access_token" not in response_body: raise ValueError('Expected "access_token" in response from oauth server') if "refresh_token" in response_body: @@ -272,23 +297,25 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials: if "expires_in" in response_body: expires_in = response_body["expires_in"] access_token = response_body["access_token"] + if "id_token" in response_body: + id_token = response_body["id_token"] - return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in) + return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in, id_token=id_token) def _request_access_token(self, auth_code) -> Credentials: if self._state != auth_code.state: raise ValueError(f"Unexpected state parameter [{auth_code.state}] passed") - self._params.update( - { - "code": auth_code.code, - "code_verifier": self._code_verifier, - "grant_type": "authorization_code", - } - ) - resp = _requests.post( + params = { + "code": auth_code.code, + "grant_type": "authorization_code", + } + + params.update(self._request_access_token_params) + + resp = self._session.post( url=self._token_endpoint, - data=self._params, + data=params, headers=self._headers, allow_redirects=False, verify=self._verify, @@ -332,13 +359,17 @@ def refresh_access_token(self, credentials: Credentials) -> Credentials: if credentials.refresh_token is None: raise ValueError("no refresh token available with which to refresh authorization credentials") - resp = _requests.post( + data = { + "refresh_token": credentials.refresh_token, + "grant_type": "refresh_token", + "client_id": self._client_id, + } + + data.update(self._refresh_access_token_params) + + resp = self._session.post( url=self._token_endpoint, - data={ - "grant_type": "refresh_token", - "client_id": self._client_id, - "refresh_token": credentials.refresh_token, - }, + data=data, headers=self._headers, 
allow_redirects=False, verify=self._verify, diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index b2b82831c7..0d9ee6ef95 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import click +import requests from . import token_client from .auth_client import AuthorizationClient @@ -95,6 +96,7 @@ def __init__( cfg_store: ClientConfigStore, header_key: typing.Optional[str] = None, verify: typing.Optional[typing.Union[bool, str]] = None, + session: typing.Optional[requests.Session] = None, ): """ Initialize with default creds from KeyStore using the endpoint name @@ -102,9 +104,16 @@ def __init__( super().__init__(endpoint, header_key, KeyringStore.retrieve(endpoint), verify=verify) self._cfg_store = cfg_store self._auth_client = None + self._session = session or requests.Session() def _initialize_auth_client(self): if not self._auth_client: + + from .auth_client import _create_code_challenge, _generate_code_verifier + + code_verifier = _generate_code_verifier() + code_challenge = _create_code_challenge(code_verifier) + cfg = self._cfg_store.get_client_config() self._set_header_key(cfg.header_key) self._auth_client = AuthorizationClient( @@ -115,6 +124,16 @@ def _initialize_auth_client(self): auth_endpoint=cfg.authorization_endpoint, token_endpoint=cfg.token_endpoint, verify=self._verify, + session=self._session, + request_auth_code_params={ + "code_challenge": code_challenge, + "code_challenge_method": "S256", + }, + request_access_token_params={ + "code_verifier": code_verifier, + }, + refresh_access_token_params={}, + add_request_auth_code_params_to_request_access_token_params=True, ) def refresh_credentials(self): @@ -176,6 +195,7 @@ def __init__( http_proxy_url: typing.Optional[str] = None, verify: typing.Optional[typing.Union[bool, str]] = None, audience: typing.Optional[str] = None, + session: typing.Optional[requests.Session] 
= None, ): if not client_id or not client_secret: raise ValueError("Client ID and Client SECRET both are required.") @@ -186,6 +206,7 @@ def __init__( self._client_id = client_id self._client_secret = client_secret self._audience = audience or cfg.audience + self._session = session or requests.Session() super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url, verify=verify) def refresh_credentials(self): @@ -211,6 +232,7 @@ def refresh_credentials(self): verify=self._verify, scopes=scopes, audience=audience, + session=self._session, ) logging.info("Retrieved new token, expires in {}".format(expires_in)) @@ -234,6 +256,7 @@ def __init__( audience: typing.Optional[str] = None, http_proxy_url: typing.Optional[str] = None, verify: typing.Optional[typing.Union[bool, str]] = None, + session: typing.Optional[requests.Session] = None, ): self._audience = audience cfg = cfg_store.get_client_config() @@ -245,6 +268,7 @@ def __init__( raise AuthenticationError( "Device Authentication is not available on the Flyte backend / authentication server" ) + self._session = session or requests.Session() super().__init__( endpoint=endpoint, header_key=header_key or cfg.header_key, @@ -255,7 +279,13 @@ def __init__( def refresh_credentials(self): resp = token_client.get_device_code( - self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url, self._verify + self._device_auth_endpoint, + self._client_id, + self._audience, + self._scope, + self._http_proxy_url, + self._verify, + self._session, ) text = f"To Authenticate, navigate in a browser to the following URL: {click.style(resp.verification_uri, fg='blue', underline=True)} and enter code: {click.style(resp.user_code, fg='blue')}" click.secho(text) diff --git a/flytekit/clients/auth/keyring.py b/flytekit/clients/auth/keyring.py index 79f5e86c68..2d4b4488f0 100644 --- a/flytekit/clients/auth/keyring.py +++ b/flytekit/clients/auth/keyring.py @@ -3,7 +3,7 @@ from 
dataclasses import dataclass import keyring as _keyring -from keyring.errors import NoKeyringError +from keyring.errors import NoKeyringError, PasswordDeleteError @dataclass @@ -16,6 +16,7 @@ class Credentials(object): refresh_token: str = "na" for_endpoint: str = "flyte-default" expires_in: typing.Optional[int] = None + id_token: typing.Optional[str] = None class KeyringStore: @@ -25,20 +26,28 @@ class KeyringStore: _access_token_key = "access_token" _refresh_token_key = "refresh_token" + _id_token_key = "id_token" @staticmethod def store(credentials: Credentials) -> Credentials: try: - _keyring.set_password( - credentials.for_endpoint, - KeyringStore._refresh_token_key, - credentials.refresh_token, - ) + if credentials.refresh_token: + _keyring.set_password( + credentials.for_endpoint, + KeyringStore._refresh_token_key, + credentials.refresh_token, + ) _keyring.set_password( credentials.for_endpoint, KeyringStore._access_token_key, credentials.access_token, ) + if credentials.id_token: + _keyring.set_password( + credentials.for_endpoint, + KeyringStore._id_token_key, + credentials.id_token, + ) except NoKeyringError as e: logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}") return credentials @@ -48,18 +57,23 @@ def retrieve(for_endpoint: str) -> typing.Optional[Credentials]: try: refresh_token = _keyring.get_password(for_endpoint, KeyringStore._refresh_token_key) access_token = _keyring.get_password(for_endpoint, KeyringStore._access_token_key) + id_token = _keyring.get_password(for_endpoint, KeyringStore._id_token_key) except NoKeyringError as e: logging.debug(f"KeyRing not available, tokens will not be cached. 
Error: {e}") return None - if not access_token: + if not access_token and not id_token: return None - return Credentials(access_token, refresh_token, for_endpoint) + return Credentials(access_token, refresh_token, for_endpoint, id_token=id_token) @staticmethod def delete(for_endpoint: str): try: _keyring.delete_password(for_endpoint, KeyringStore._access_token_key) _keyring.delete_password(for_endpoint, KeyringStore._refresh_token_key) + try: + _keyring.delete_password(for_endpoint, KeyringStore._id_token_key) + except PasswordDeleteError as e: + logging.debug(f"Id token not found in key store, not deleting. Error: {e}") except NoKeyringError as e: logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}") diff --git a/flytekit/clients/auth/token_client.py b/flytekit/clients/auth/token_client.py index e5eae32ed7..4584866b21 100644 --- a/flytekit/clients/auth/token_client.py +++ b/flytekit/clients/auth/token_client.py @@ -78,6 +78,7 @@ def get_token( grant_type: GrantType = GrantType.CLIENT_CREDS, http_proxy_url: typing.Optional[str] = None, verify: typing.Optional[typing.Union[bool, str]] = None, + session: typing.Optional[requests.Session] = None, ) -> typing.Tuple[str, int]: """ :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration @@ -103,7 +104,10 @@ def get_token( body["audience"] = audience proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None - response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify) + + if not session: + session = requests.Session() + response = session.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify) if not response.ok: j = response.json() @@ -125,6 +129,7 @@ def get_device_code( scope: typing.Optional[typing.List[str]] = None, http_proxy_url: typing.Optional[str] = None, verify: typing.Optional[typing.Union[bool, str]] = None, + session: 
typing.Optional[requests.Session] = None, ) -> DeviceCodeResponse: """ Retrieves the device Authentication code that can be done to authenticate the request using a browser on a @@ -133,7 +138,9 @@ def get_device_code( _scope = " ".join(scope) if scope is not None else "" payload = {"client_id": client_id, "scope": _scope, "audience": audience} proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None - resp = requests.post(device_auth_endpoint, payload, proxies=proxies, verify=verify) + if not session: + session = requests.Session() + resp = session.post(device_auth_endpoint, payload, proxies=proxies, verify=verify) if not resp.ok: raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}") return DeviceCodeResponse.from_json_response(resp.json()) diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py index bdff000623..75bc52378e 100644 --- a/flytekit/clients/auth_helper.py +++ b/flytekit/clients/auth_helper.py @@ -1,7 +1,9 @@ import logging import ssl +from http import HTTPStatus import grpc +import requests from flyteidl.service.auth_pb2 import OAuth2MetadataRequest, PublicClientAuthConfigRequest from flyteidl.service.auth_pb2_grpc import AuthMetadataServiceStub from OpenSSL import crypto @@ -66,8 +68,10 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth elif cfg.ca_cert_file_path: verify = cfg.ca_cert_file_path + session = get_session(cfg) + if cfg_auth == AuthType.STANDARD or cfg_auth == AuthType.PKCE: - return PKCEAuthenticator(cfg.endpoint, cfg_store, verify=verify) + return PKCEAuthenticator(cfg.endpoint, cfg_store, verify=verify, session=session) elif cfg_auth == AuthType.BASIC or cfg_auth == AuthType.CLIENT_CREDENTIALS or cfg_auth == AuthType.CLIENTSECRET: return ClientCredentialsAuthenticator( endpoint=cfg.endpoint, @@ -78,6 +82,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth 
audience=cfg.audience, http_proxy_url=cfg.http_proxy_url, verify=verify, + session=session, ) elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND: client_cfg = None @@ -94,6 +99,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth audience=cfg.audience, http_proxy_url=cfg.http_proxy_url, verify=verify, + session=session, ) else: raise ValueError( @@ -101,6 +107,28 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth ) +def get_proxy_authenticator(cfg: PlatformConfig) -> Authenticator: + return CommandAuthenticator( + command=cfg.proxy_command, + header_key="proxy-authorization", + ) + + +def upgrade_channel_to_proxy_authenticated(cfg: PlatformConfig, in_channel: grpc.Channel) -> grpc.Channel: + """ + If activated in the platform config, given a grpc.Channel, preferrably a secure channel, it returns a composed + channel that uses Interceptor to perform authentication with a proxy infront of Flyte + :param cfg: PlatformConfig + :param in_channel: grpc.Channel Precreated channel + :return: grpc.Channel. 
New composite channel + """ + if cfg.proxy_command: + proxy_authenticator = get_proxy_authenticator(cfg) + return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(proxy_authenticator)) + else: + return in_channel + + def upgrade_channel_to_authenticated(cfg: PlatformConfig, in_channel: grpc.Channel) -> grpc.Channel: """ Given a grpc.Channel, preferrably a secure channel, it returns a composed channel that uses Interceptor to @@ -122,6 +150,7 @@ def get_authenticated_channel(cfg: PlatformConfig) -> grpc.Channel: if cfg.insecure else grpc.secure_channel(cfg.endpoint, grpc.ssl_channel_credentials()) ) # noqa + channel = upgrade_channel_to_proxy_authenticated(cfg, channel) return upgrade_channel_to_authenticated(cfg, channel) @@ -213,3 +242,64 @@ def wrap_exceptions_channel(cfg: PlatformConfig, in_channel: grpc.Channel) -> gr :return: grpc.Channel """ return grpc.intercept_channel(in_channel, RetryExceptionWrapperInterceptor(max_retries=cfg.rpc_retries)) + + +class AuthenticationHTTPAdapter(requests.adapters.HTTPAdapter): + """ + A custom HTTPAdapter that adds authentication headers to requests of a session. + """ + + def __init__(self, authenticator, *args, **kwargs): + self.authenticator = authenticator + super().__init__(*args, **kwargs) + + def add_auth_header(self, request): + """ + Adds authentication headers to the request. + :param request: The request object to add headers to. + """ + if self.authenticator.get_credentials() is None: + self.authenticator.refresh_credentials() + + auth_header_key, auth_header_val = self.authenticator.fetch_grpc_call_auth_metadata() + request.headers[auth_header_key] = auth_header_val + + def send(self, request, *args, **kwargs): + """ + Sends the request with added authentication headers. + If the response returns a 401 status code, refreshes the credentials and retries the request. + :param request: The request object to send. + :return: The response object. 
+ """ + self.add_auth_header(request) + response = super().send(request, *args, **kwargs) + if response.status_code == HTTPStatus.UNAUTHORIZED: + self.authenticator.refresh_credentials() + self.add_auth_header(request) + response = super().send(request, *args, **kwargs) + return response + + +def upgrade_session_to_proxy_authenticated(cfg: PlatformConfig, session: requests.Session) -> requests.Session: + """ + Given a requests.Session, it returns a new session that uses a custom HTTPAdapter to + perform authentication with a proxy infront of Flyte + + :param cfg: PlatformConfig + :param session: requests.Session Precreated session + :return: requests.Session. New session with custom HTTPAdapter mounted + """ + proxy_authenticator = get_proxy_authenticator(cfg) + adapter = AuthenticationHTTPAdapter(proxy_authenticator) + + session.mount("http://", adapter) + session.mount("https://", adapter) + return session + + +def get_session(cfg: PlatformConfig, **kwargs) -> requests.Session: + """Return a new session for the given platform config.""" + session = requests.Session() + if cfg.proxy_command: + session = upgrade_session_to_proxy_authenticated(cfg, session) + return session diff --git a/flytekit/clients/grpc_utils/auth_interceptor.py b/flytekit/clients/grpc_utils/auth_interceptor.py index 53f178a9a9..e467801a77 100644 --- a/flytekit/clients/grpc_utils/auth_interceptor.py +++ b/flytekit/clients/grpc_utils/auth_interceptor.py @@ -61,7 +61,7 @@ def intercept_unary_unary( fut: grpc.Future = continuation(updated_call_details, request) e = fut.exception() if e: - if e.code() == grpc.StatusCode.UNAUTHENTICATED: + if e.code() == grpc.StatusCode.UNAUTHENTICATED or e.code() == grpc.StatusCode.UNKNOWN: self._authenticator.refresh_credentials() updated_call_details = self._call_details_with_auth_metadata(client_call_details) return continuation(updated_call_details, request) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 836d5ffa3b..6cb80d4b8f 100644 --- 
a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -11,7 +11,12 @@ from flyteidl.service import signal_pb2_grpc as signal_service from flyteidl.service.dataproxy_pb2_grpc import DataProxyServiceStub -from flytekit.clients.auth_helper import get_channel, upgrade_channel_to_authenticated, wrap_exceptions_channel +from flytekit.clients.auth_helper import ( + get_channel, + upgrade_channel_to_authenticated, + upgrade_channel_to_proxy_authenticated, + wrap_exceptions_channel, +) from flytekit.configuration import PlatformConfig from flytekit.loggers import cli_logger @@ -41,7 +46,9 @@ def __init__(self, cfg: PlatformConfig, **kwargs): insecure: if insecure is desired """ self._cfg = cfg - self._channel = wrap_exceptions_channel(cfg, upgrade_channel_to_authenticated(cfg, get_channel(cfg))) + self._channel = wrap_exceptions_channel( + cfg, upgrade_channel_to_authenticated(cfg, upgrade_channel_to_proxy_authenticated(cfg, get_channel(cfg))) + ) self._stub = _admin_service.AdminServiceStub(self._channel) self._signal = signal_service.SignalServiceStub(self._channel) self._dataproxy_stub = dataproxy_service.DataProxyServiceStub(self._channel) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 8e5ccf2fe2..0058bdc551 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -373,6 +373,7 @@ class PlatformConfig(object): :param insecure_skip_verify: Whether to skip SSL certificate verification :param console_endpoint: endpoint for console if different from Flyte backend :param command: This command is executed to return a token using an external process + :param proxy_command: This command is executed to return a token for proxy authorization using an external process :param client_id: This is the public identifier for the app which handles authorization for a Flyte deployment. More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. 
:param client_credentials_secret: Used for service auth, which is automatically called during pyflyte. This will @@ -390,6 +391,7 @@ class PlatformConfig(object): ca_cert_file_path: typing.Optional[str] = None console_endpoint: typing.Optional[str] = None command: typing.Optional[typing.List[str]] = None + proxy_command: typing.Optional[typing.List[str]] = None client_id: typing.Optional[str] = None client_credentials_secret: typing.Optional[str] = None scopes: List[str] = field(default_factory=list) @@ -413,6 +415,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None ) kwargs = set_if_exists(kwargs, "ca_cert_file_path", _internal.Platform.CA_CERT_FILE_PATH.read(config_file)) kwargs = set_if_exists(kwargs, "command", _internal.Credentials.COMMAND.read(config_file)) + kwargs = set_if_exists(kwargs, "proxy_command", _internal.Credentials.PROXY_COMMAND.read(config_file)) kwargs = set_if_exists(kwargs, "client_id", _internal.Credentials.CLIENT_ID.read(config_file)) kwargs = set_if_exists( kwargs, "client_credentials_secret", _internal.Credentials.CLIENT_CREDENTIALS_SECRET.read(config_file) diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index f34321f57b..9d1980c450 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -64,6 +64,13 @@ class Credentials(object): This command is executed to return a token using an external process. """ + PROXY_COMMAND = ConfigEntry( + LegacyConfigEntry(SECTION, "proxy_command", list), YamlConfigEntry("admin.proxyCommand", list) + ) + """ + This command is executed to return a token for authorization with a proxy in front of Flyte using an external process. + """ + CLIENT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "client_id"), YamlConfigEntry("admin.clientId")) """ This is the public identifier for the app which handles authorization for a Flyte deployment. 
diff --git a/plugins/flytekit-identity-aware-proxy/README.md b/plugins/flytekit-identity-aware-proxy/README.md new file mode 100644 index 0000000000..c6c631707c --- /dev/null +++ b/plugins/flytekit-identity-aware-proxy/README.md @@ -0,0 +1,404 @@ +# Flytekit Identity Aware Proxy + +[GCP Identity Aware Proxy (IAP)](https://cloud.google.com/iap) is a managed Google Cloud Platform (GCP) service that makes it easy to protect applications deployed on GCP by verifying user identity and using context to determine whether a user should be granted access. Because requests to applications protected with IAP first have to pass IAP before they can reach the protected backends, IAP provides a convenient way to implement a zero-trust access model. + +This flytekit plugin allows users to generate ID tokens via an external command for use with Flyte deployments protected with IAP. A step-by-step guide to protect a Flyte deployment with IAP is provided as well. + +**Disclaimer: Do not choose this deployment path with the goal of *a* Flyte deployment configured with authentication on GCP. The deployment is more involved than the standard Flyte GCP deployment. Follow this guide if your organization has a security policy that requires the use of GCP Identity Aware Proxy.** + +## Configuring the token generation CLI provided by this plugin + +1. Install this plugin via `pip install flytekitplugins-identity-aware-proxy`. + + Verify the installation with `flyte-iap --help`. + +2. Create OAuth 2.0 credentials for both the token generation CLI and for IAP. + 1. [Desktop OAuth credentials](https://cloud.google.com/iap/docs/authentication-howto#authenticating_from_a_desktop_app) for this CLI: + + In the GCP cloud console navigate to *"APIs & Services" / "Credentials"*, click *"Create Credentials"*, select "*OAuth Client ID*", and finally choose *“Desktop App”*. + + Note the client id and client secret. + + 2. 
Follow the instructions to [activate IAP](https://cloud.google.com/iap/docs/enabling-kubernetes-howto#enabling_iap) in your project and cluster. In the process you will create web application type OAuth credentials for IAP (similar to what was done above for the desktop application type credentials). Again, note the client id and client secret. Don't proceed with the instructions to create the Kubernetes secret for these credentials and the backend config yet; this is done in the deployment guide below. Stop when you have the client id and secret. + + Note: In case you have an existing [Flyte deployment with auth configured](https://docs.flyte.org/en/latest/deployment/configuration/auth_setup.html#apply-oidc-configuration), you likely already have web application type OAuth credentials. You can reuse those credentials for Flyte's IAP. + +3. The token generation CLI provided by this plugin requires 1) the desktop application type client id and client secret to issue an ID token for IAP as well as 2) the client id (not the secret) of the web app type credentials that will be used by IAP (as the audience of the token). + + The desktop client secret needs to be kept secret. Therefore, create a GCP secret manager secret with the desktop client secret. + + Note the name of the secret and the id of the GCP project containing the secret. + + (You will have to grant users that will use the token generation CLI access to the secret.) + +4. Test the token generation CLI: + + ```console + flyte-iap generate-user-id-token \ + --desktop_client_id < fill in desktop client id> \ + --desktop_client_secret_gcp_secret_name < fill in the name of the gcp secret manager secret > \ + --webapp_client_id < fill in the web app client id> \ + --project < fill in the gcp project id where the secret was saved > + ``` + + A browser window should open, asking you to log in with your GCP account. Then, a successful login should be confirmed with *"Successfully logged into accounts.google.com"*. 
+ + Finally, the token beginning with `eyJhbG...` should be printed to the console. + + You can decode the token with: + + ```console + jq -R 'split(".") | select(length > 0) | .[0],.[1] | @base64d | fromjson' <<< "eyJhbG..." + ``` + + The token should be issued by `"https://accounts.google.com"`, should contain your email, and should have the desktop client id set as `"azp"` and the web app client id set as `"aud"` (audience). + +5. Configure proxy authorization with this CLI in `~/.flyte/config.yaml`: + + ```yaml + admin: + endpoint: dns:///< fill in >.com + insecure: false + insecureSkipVerify: true + authType: Pkce + proxyCommand: ["flyte-iap", "generate-user-id-token", "--desktop_client_id", ...] # Add this line + ``` + + This configures the Flyte clients to send `"proxy-authorization"` headers with the token generated by the CLI with every request in order to pass the GCP Identity Aware Proxy. + + 6. For registering workflows from CI/CD, you might have to generate ID tokens for GCP service accounts instead of user accounts. For this purpose, you have the following options: + * `flyte-iap` provides a second subcommand called `generate-service-account-id-token`. This subcommand uses either a service account key json file to obtain an ID token or alternatively obtains one from the metadata server when being run on GCP Compute Engine, App Engine, or Cloud Run. It caches tokens and only obtains a new one when the cached token is about to expire. + * If you want to avoid a flytekit/python dependency in your CI/CD systems, you can use the `gcloud` sdk: + + ``` + gcloud auth print-identity-token --token-format=full --audiences="< fill in >.apps.googleusercontent.com" + ``` + * Adapt [this bash script](https://cloud.google.com/iap/docs/authentication-howto#obtaining_an_oidc_token_from_a_local_service_account_key_file) from the GCP Identity Aware Proxy documentation which retrieves a token in exchange for service account credentials.
(You would need to replace the `curl` command in the last line with `echo $ID_TOKEN`.) + +## Configuring your Flyte deployment to use IAP + +### Introduction + +To protect your Flyte deployment with IAP, we have to deploy it with a GCE ingress (instead of the Nginx ingress used by the default Flyte deployment). + +Flyteadmin has a gRPC endpoint. The gRPC protocol requires the use of http2. When using http2 between a GCP load balancer (created by the GCE ingress) and a backend in GKE, the use of TLS is required ([see documentation](https://cloud.google.com/kubernetes-engine/docs/how-to/ingress-http2)): + +> To ensure the load balancer can make a correct HTTP2 request to your backend, your backend must be configured with SSL. + +The following deployment guide follows [this](https://cloud.google.com/architecture/exposing-service-mesh-apps-through-gke-ingress) reference architecture for the Istio service mesh on Google Kubernetes Engine. + +We will configure an Istio ingress gateway (pod) deployed behind a GCP load balancer to use http2 and TLS (see [here](https://cloud.google.com/architecture/exposing-service-mesh-apps-through-gke-ingress#security)): + +> you can enable HTTP/2 with TLS encryption between the cluster ingress [...] and the mesh ingress (the envoy proxy instance). When you enable HTTP/2 with TLS encryption for this path, you can use a self-signed or public certificate to encrypt traffic [...] + +Flyte is then deployed behind the Istio ingress gateway and does not need to be configured to use TLS itself. + +*Note that we do not do this for security reasons but to enable http2 traffic (required by gRPC) into the cluster through a GCE Ingress (which is required by IAP).* + +### Deployment + +1. If not already done, deploy the flyte-core helm chart, [activating auth](https://docs.flyte.org/en/latest/deployment/configuration/auth_setup.html#apply-oidc-configuration). Re-use the web app client id created for IAP (see section above). 
Disable the default ingress in the helm values by setting `common.ingress.enabled` to `false` in the helm values file. + + +2. Deployment of Istio and the Istio ingress gateway ([docs](https://istio.io/latest/docs/setup/install/helm/)) + + * `helm repo add istio https://istio-release.storage.googleapis.com/charts` + * `helm repo update` + * `kubectl create namespace istio-system` + * `helm install istio-base istio/base -n istio-system` + * `helm install istiod istio/istiod -n istio-system --wait` + * `helm install istio-ingress istio/gateway -n istio-system -f istio-values.yaml --wait` + + Here, `istio-values.yaml` contains the following: + + ```yaml + service: + annotations: + beta.cloud.google.com/backend-config: '{"default": "ingress-backend-config"}' + cloud.google.com/app-protocols: '{"https": "HTTP2"}' + type: + NodePort + ``` + + It is crucial that the service type is set to `NodePort` and not the default `LoadBalancer`. Otherwise, the Istio ingress gateway won't be deployed behind the GCP load balancer we create below but would be **publicly available on the internet!** + + With the annotations we configured the service to use http2 which is required by gRPC. We also configured the service to use a so-called backend config `ingress-backend-config` which activates IAP and which we will create in the next step. + + +3. Activate IAP for the Istio ingress gateway via a backend config: + + Create a Kubernetes secret containing the web app client id and secret we created above. The creation of the secret is described [here](https://cloud.google.com/iap/docs/enabling-kubernetes-howto#kubernetes-configure). From now on the assumption is that the secret is called `iap-oauth-client-id`. 
+ + Create a backend config for the Istio ingress gateway: + + ```yaml + apiVersion: cloud.google.com/v1 + kind: BackendConfig + metadata: + name: ingress-backend-config + namespace: istio-system + spec: + healthCheck: + port: 15021 + requestPath: /healthz/ready + type: HTTP + iap: + enabled: true + oauthclientCredentials: + secretName: iap-oauth-client-id + ``` + + Note that apart from activating IAP, we also configured a custom health check as the istio ingress gateway doesn't use the default health check path and port assumed by the GCP load balancer. + + +4. [Install Cert Manager](https://cert-manager.io/docs/installation/helm/) to [create and rotate](https://cert-manager.io/docs/configuration/selfsigned/) a self-signed certificate for the istio ingress (pod): + + * `helm repo add jetstack https://charts.jetstack.io` + * `helm repo update` + * `helm install cert-manager jetstack/cert-manager --namespace cert-manager --create-namespace --set installCRDs=true` + + Create the following objects: + + ```yaml + apiVersion: cert-manager.io/v1 + kind: Issuer + metadata: + name: selfsigned-issuer + namespace: istio-system + spec: + selfSigned: {} + ``` + + ```yaml + apiVersion: cert-manager.io/v1 + kind: Certificate + metadata: + name: istio-ingress-cert + namespace: istio-system + spec: + commonName: istio-ingress + dnsNames: + - istio-ingress + - istio-ingress.istio-system.svc + - istio-ingress.istio-system.svc.cluster.local + issuerRef: + kind: Issuer + name: selfsigned-issuer + secretName: istio-ingress-cert + ``` + + This self-signed TLS certificate is only used between the GCP load balancer and the istio ingress gateway. It is not used by the istio ingress gateway to terminate TLS connections from the outside world (as we created it using a `NodePort` type service). Therefore, it is not unsafe to use a self-signed certificate here. Many applications deployed on GKE don't use any additional encryption between the load balancer and the backend. 
GCP, however, [encrypts these connections by default](https://cloud.google.com/load-balancing/docs/backend-service#encryption_between_the_load_balancer_and_backends): + + > The next hop, which is between the Google Front End (GFE) and the mesh ingress proxy, is encrypted by default. Network-level encryption between the GFEs and their backends is applied automatically. However, if your security requirements dictate that the platform owner retain ownership of the encryption keys, then you can enable HTTP/2 with TLS encryption between the cluster ingress (the GFE) and the mesh ingress (the envoy proxy instance). + + This additional self-managed encryption is also required to use http2 and in extension gRPC. To repeat, we mainly add this self-signed certificate in order to be able to expose a gRPC service (flyteadmin) via a GCP load balancer, less for the additional encryption. + + +5. Configure the istio ingress gateway to use the self-signed certificate: + + + ```yaml + apiVersion: networking.istio.io/v1beta1 + kind: Gateway + metadata: + name: default-gateway + namespace: istio-system + spec: + selector: + app: istio-ingress + istio: ingress + servers: + - hosts: + - '*' + port: + name: https + number: 443 + protocol: HTTPS + tls: + credentialName: istio-ingress-cert + mode: SIMPLE + ``` + + (Note that the `credentialName` matches the `secretName` in the `Certificate` we created.) + + This `Gateway` object configures the Istio ingress gateway (pod) to use the self-signed certificate we created above for every incoming TLS connection. + + +6. Deploy the GCE ingress that will route traffic to the istio ingress gateway: + + + * Create a global (not regional) static IP address in GCP as is described [here](https://cloud.google.com/kubernetes-engine/docs/how-to/managed-certs#prerequisites). + * Create a DNS record for your Flyte domain to route traffic to this static IP address. 
+ * Create a GCP managed certificate (please fill in your domain): + + ```yaml + apiVersion: networking.gke.io/v1 + kind: ManagedCertificate + metadata: + name: flyte-managed-certificate + namespace: istio-system + spec: + domains: + - < fill in your domain > + ``` + * Create the ingress (please fill in the name of the static IP): + + ```yaml + apiVersion: networking.k8s.io/v1 + kind: Ingress + metadata: + annotations: + kubernetes.io/ingress.allow-http: "true" + kubernetes.io/ingress.global-static-ip-name: "< fill in >" + networking.gke.io/managed-certificates: flyte-managed-certificate + networking.gke.io/v1beta1.FrontendConfig: ingress-frontend-config + name: flyte-ingress + namespace: istio-system + spec: + rules: + - http: + paths: + - backend: + service: + name: istio-ingress + port: + number: 443 + path: / + pathType: Prefix + --- + apiVersion: networking.gke.io/v1beta1 + kind: FrontendConfig + metadata: + name: ingress-frontend-config + namespace: istio-system + spec: + redirectToHttps: + enabled: true + responseCodeName: MOVED_PERMANENTLY_DEFAULT + ``` + + This ingress routes all traffic to the istio ingress gateway via http2 and TLS. + + For clarity: The GCP load balancer TLS terminates connections coming from the outside world using a GCP managed certificate. + The self-signed certificate created above is only used between the GCP load balancer and the istio ingress gateway running in the cluster. + To repeat, because of this it is important for security that the istio ingress gateway uses a `NodePort` type service and not a `LoadBalancer`. + + * In the GCP cloud console under *Kubernetes Engine/Services & Ingress/Ingress* (selecting the respective cluster and the `istio-system` namespace), you can observe the status of the ingress, its managed certificate, and its backends. Only proceed if all statuses are green. 
The creation of the GCP load balancer configured by the ingress and of the managed certificate can take up to 30 minutes during the first deployment. + + +7. Connect flyteadmin and flyteconsole to the istio ingress gateway: + + So far, we created a GCE ingress (which creates a GCP load balancer). The load balancer is configured to forward all requests to the istio ingress gateway at the edge of the service mesh via http2 and TLS. + + Next, we configure the Istio service mesh to route requests from the Istio ingress gateway to flyteadmin and flyteconsole. + + In istio, this is configured using a so-called `VirtualService` object. + + Please fill in your flyte domain in the following manifest and apply it to the cluster: + + ```yaml + apiVersion: networking.istio.io/v1beta1 + kind: VirtualService + metadata: + name: flyte-virtualservice + namespace: flyte + spec: + gateways: + - istio-system/default-gateway + hosts: + - < fill in your flyte domain > + http: + - match: + - uri: + prefix: /console + name: console-routes + route: + - destination: + host: flyteconsole + port: + number: 80 + - match: + - uri: + prefix: /api + - uri: + prefix: /healthcheck + - uri: + prefix: /v1/* + - uri: + prefix: /.well-known + - uri: + prefix: /login + - uri: + prefix: /logout + - uri: + prefix: /callback + - uri: + prefix: /me + - uri: + prefix: /config + - uri: + prefix: /oauth2 + name: admin-routes + route: + - destination: + host: flyteadmin + port: + number: 80 + - match: + - uri: + prefix: /flyteidl.service.SignalService + - uri: + prefix: /flyteidl.service.AdminService + - uri: + prefix: /flyteidl.service.DataProxyService + - uri: + prefix: /flyteidl.service.AuthMetadataService + - uri: + prefix: /flyteidl.service.IdentityService + - uri: + prefix: /grpc.health.v1.Health + name: admin-grpc-routes + route: + - destination: + host: flyteadmin + port: + number: 81 + ``` + + In this `VirtualService`, the routing rules for flyteadmin and flyteconsole are configured which in Flyte's default deployment are
configured in the Nginx ingress. + + Note that the virtual service references the `Gateway` object we created above which configures the istio ingress gateway to use TLS for these connections. + +8. Test your Flyte deployment with IAP by e.g. executing this python script: + + ```python + from flytekit.remote import FlyteRemote + + from flytekit.configuration import Config + + + remote = FlyteRemote( + config=Config.auto(), + default_project="flytesnacks", + default_domain="development", + ) + + + print(remote.recent_executions()) + ``` + + A browser window should open and ask you to log in with your Google account. You should then see confirmation that you *"Successfully logged into accounts.google.com"* (this was for the IAP), finally followed by confirmation that you *"Successfully logged into 'your flyte domain'"* (this was for Flyte itself). + + + +9. At this point your Flyte deployment should be successfully protected by a GCP identity aware proxy using a zero trust model. + + You should check in the GCP cloud console's *IAP* page that IAP is actually activated and configured correctly for the Istio ingress gateway (follow up on any yellow or red status symbols next to the respective backend). + + You could also open the flyte console in an incognito browser window and verify that you are asked to log in with your Google account. + + Finally, you could also comment out the `proxyCommand` line in your `~/.flyte/config.yaml` and verify that you are no longer able to access your Flyte deployment behind IAP. + +10. The double login observed above is due to the fact that the Flyte clients send `"proxy-authorization"` headers generated by the CLI provided by this plugin with every request in order to make it past IAP. They still also send the regular `"authorization"` header issued by flyteadmin itself. + + Since the refresh token for Flyte and the one for IAP by default don't have the same lifespan, you likely won't notice this double login again.
However, since your deployment is already protected by IAP, the ID token (issued by flyteadmin) in the `"authorization"` header mostly serves to identify users. Therefore, you can consider to increase the lifespan of the refresh token issued by flyteadmin to e.g. 7 days by setting `configmap.adminServer.auth.appAuth.selfAuthServer.refreshTokenLifespan` to e.g. `168h0m0s` in your Flyte helm values file. This way, your users should barely notice the double login. diff --git a/plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/__init__.py b/plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/cli.py b/plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/cli.py new file mode 100644 index 0000000000..3c70429848 --- /dev/null +++ b/plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/cli.py @@ -0,0 +1,248 @@ +import logging +import os +import typing + +import click +import jwt +from google.api_core.exceptions import NotFound +from google.auth import default +from google.auth.transport.requests import Request +from google.cloud import secretmanager +from google.oauth2 import id_token + +from flytekit.clients.auth.auth_client import AuthorizationClient +from flytekit.clients.auth.authenticator import Authenticator +from flytekit.clients.auth.exceptions import AccessTokenNotFoundError +from flytekit.clients.auth.keyring import Credentials, KeyringStore + +WEBAPP_CLIENT_ID_HELP = ( + "Webapp type OAuth 2.0 client ID used by the IAP. " + "Typically in the form of `.apps.googleusercontent.com`. " + "Created when activating IAP for the Flyte deployment. 
" + "https://cloud.google.com/iap/docs/enabling-kubernetes-howto#oauth-credentials" +) + + +class GCPIdentityAwareProxyAuthenticator(Authenticator): + """ + This Authenticator encapsulates the entire OAuth 2.0 flow with GCP Identity Aware Proxy. + + The auth flow is described in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application + + Automatically opens a browser window for login. + """ + + def __init__( + self, + audience: str, + client_id: str, + client_secret: str, + verify: typing.Optional[typing.Union[bool, str]] = None, + ): + """ + Initialize with default creds from KeyringStore using the audience name. + """ + super().__init__(audience, "proxy-authorization", KeyringStore.retrieve(audience), verify=verify) + self._auth_client = None + + self.audience = audience + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = "http://localhost:4444" + + def _initialize_auth_client(self): + if not self._auth_client: + self._auth_client = AuthorizationClient( + endpoint=self.audience, + # See step 3 in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application + auth_endpoint="https://accounts.google.com/o/oauth2/v2/auth", + token_endpoint="https://oauth2.googleapis.com/token", + # See step 3 in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application + scopes=["openid", "email"], + client_id=self.client_id, + redirect_uri=self.redirect_uri, + verify=self._verify, + # See step 3 in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application + request_auth_code_params={ + "cred_ref": "true", + "access_type": "offline", + }, + # See step 4 in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application + request_access_token_params={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "audience": self.audience, + "redirect_uri": self.redirect_uri, + }, + # See
https://cloud.google.com/iap/docs/authentication-howto#refresh_token + refresh_access_token_params={ + "client_secret": self.client_secret, + "audience": self.audience, + }, + ) + + def refresh_credentials(self): + """Refresh the IAP credentials. If no credentials are cached or the refresh fails, kick off a full OAuth 2.0 authorization flow.""" + self._initialize_auth_client() + if self._creds: + """We have an id token so lets try to refresh it""" + try: + self._creds = self._auth_client.refresh_access_token(self._creds) + if self._creds: + KeyringStore.store(self._creds) + return + except AccessTokenNotFoundError: + logging.warning("Failed to refresh token. Kicking off a full authorization flow.") + KeyringStore.delete(self._endpoint) + + self._creds = self._auth_client.get_creds_from_remote() + KeyringStore.store(self._creds) + + +def get_gcp_secret_manager_secret(project_id: str, secret_id: str, version: typing.Optional[str] = "latest"): + """Retrieve secret from GCP secret manager.""" + client = secretmanager.SecretManagerServiceClient() + name = f"projects/{project_id}/secrets/{secret_id}/versions/{version}" + try: + response = client.access_secret_version(name=name) + except NotFound as e: + raise click.BadParameter(e.message) + payload = response.payload.data.decode("UTF-8") + return payload + + +@click.group() +def cli(): + """Generate ID tokens for GCP Identity Aware Proxy (IAP).""" + pass + + +@cli.command() +@click.option( + "--desktop_client_id", + type=str, + default=None, + required=True, + help=( + "Desktop type OAuth 2.0 client ID. Typically in the form of `.apps.googleusercontent.com`. " + "Create by following https://cloud.google.com/iap/docs/authentication-howto#setting_up_the_client_id" + ), +) +@click.option( + "--desktop_client_secret_gcp_secret_name", + type=str, + default=None, + required=True, + help=( + "Name of a GCP secret manager secret containing the desktop type OAuth 2.0 client secret " + "obtained together with desktop type OAuth 2.0 client ID."
+ ), +) +@click.option( + "--webapp_client_id", + type=str, + default=None, + required=True, + help=WEBAPP_CLIENT_ID_HELP, +) +@click.option( + "--project", + type=str, + default=None, + required=True, + help="GCP project ID (in which `desktop_client_secret_gcp_secret_name` is saved).", +) +def generate_user_id_token( + desktop_client_id: str, desktop_client_secret_gcp_secret_name: str, webapp_client_id: str, project: str +): + """Generate a user account ID token for proxy-authorization with GCP Identity Aware Proxy.""" + desktop_client_secret = get_gcp_secret_manager_secret(project, desktop_client_secret_gcp_secret_name) + + iap_authenticator = GCPIdentityAwareProxyAuthenticator( + audience=webapp_client_id, + client_id=desktop_client_id, + client_secret=desktop_client_secret, + ) + try: + iap_authenticator.refresh_credentials() + except Exception as e: + raise click.ClickException(f"Failed to obtain credentials for GCP Identity Aware Proxy (IAP): {e}") + + click.echo(iap_authenticator.get_credentials().id_token) + + +def get_service_account_id_token(audience: str, service_account_email: str) -> str: + """Fetch an ID Token for the service account used by the current environment. + + Uses flytekit's KeyringStore to cache the ID token. + + This function acquires the ID token from the environment in the following order. + See https://google.aip.dev/auth/4110. + + 1. If the environment variable ``GOOGLE_APPLICATION_CREDENTIALS`` is set + to the path of a valid service account JSON file, then the ID token is + acquired using these service account credentials. + 2. If the application is running in Compute Engine, App Engine or Cloud Run, + then the ID token is obtained from the metadata server. + + Args: + audience (str): The audience that this ID token is intended for. + service_account_email (str): The email address of the service account.
+ """ + # Flytekit's KeyringStore, by default, uses the endpoint as the key to store the credentials + # We use the audience and the service account email as the key + audience_and_account_key = audience + "-" + service_account_email + creds = KeyringStore.retrieve(audience_and_account_key) + if creds: + is_expired = False + try: + exp_margin = -300 # Negative leeway (seconds): generate a new token if the cached one expires in less than 5 minutes + jwt.decode( + creds.id_token.encode("utf-8"), + options={"verify_signature": False, "verify_exp": True}, + leeway=exp_margin, + ) + except jwt.ExpiredSignatureError: + is_expired = True + + if not is_expired: + return creds.id_token + + token = id_token.fetch_id_token(Request(), audience) + + KeyringStore.store(Credentials(for_endpoint=audience_and_account_key, access_token="", id_token=token)) + return token + + +@cli.command() +@click.option( + "--webapp_client_id", + type=str, + default=None, + required=True, + help=WEBAPP_CLIENT_ID_HELP, +) +@click.option( + "--service_account_key", + type=click.Path(exists=True, dir_okay=False), + default=None, + required=False, + help=( + "Path to a service account key file. Alternatively set the environment variable " + "`GOOGLE_APPLICATION_CREDENTIALS` to the path of the service account key file. " + "If not provided and in Compute Engine, App Engine, or Cloud Run, will retrieve " + "the ID token from the metadata server."
+ ), +) +def generate_service_account_id_token(webapp_client_id: str, service_account_key: str): + """Generate a service account ID token for proxy-authorization with GCP Identity Aware Proxy.""" + if service_account_key: + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = service_account_key + + application_default_credentials, _ = default() + token = get_service_account_id_token(webapp_client_id, application_default_credentials.service_account_email) + click.echo(token) + + +if __name__ == "__main__": + cli() diff --git a/plugins/flytekit-identity-aware-proxy/setup.py b/plugins/flytekit-identity-aware-proxy/setup.py new file mode 100644 index 0000000000..33f8af248d --- /dev/null +++ b/plugins/flytekit-identity-aware-proxy/setup.py @@ -0,0 +1,43 @@ +from setuptools import setup + +PLUGIN_NAME = "identity_aware_proxy" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["click", "google-cloud-secret-manager", "google-auth", "flytekit"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="External command plugin to generate ID tokens for GCP Identity Aware Proxy", + url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-identity-aware-proxy", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + entry_points={ + "console_scripts": [ + "flyte-iap=flytekitplugins.identity_aware_proxy.cli:cli", + ], + }, + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + 
"Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-identity-aware-proxy/tests/__init__.py b/plugins/flytekit-identity-aware-proxy/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-identity-aware-proxy/tests/test_flytekitplugins_iap.py b/plugins/flytekit-identity-aware-proxy/tests/test_flytekitplugins_iap.py new file mode 100644 index 0000000000..766ff646ab --- /dev/null +++ b/plugins/flytekit-identity-aware-proxy/tests/test_flytekitplugins_iap.py @@ -0,0 +1,146 @@ +import uuid +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +import click +import jwt +import pytest +from click.testing import CliRunner +from flytekitplugins.identity_aware_proxy.cli import cli, get_gcp_secret_manager_secret, get_service_account_id_token +from google.api_core.exceptions import NotFound + + +def test_help() -> None: + """Smoke test external command IAP ID token generator cli by printing help message.""" + runner = CliRunner() + result = runner.invoke(cli, "--help") + assert "Generate ID tokens" in result.output + assert result.exit_code == 0 + + result = runner.invoke(cli, ["generate-user-id-token", "--help"]) + assert "Generate a user account ID token" in result.output + assert result.exit_code == 0 + + result = runner.invoke(cli, ["generate-service-account-id-token", "--help"]) + assert "Generate a service account ID token" in result.output + assert result.exit_code == 0 + + +def test_get_gcp_secret_manager_secret(): + """Test retrieval of GCP secret manager secret.""" + project_id = "test_project" + secret_id = "test_secret" + version = "latest" + expected_payload = "test_payload" + + mock_client = MagicMock() + mock_client.access_secret_version.return_value.payload.data.decode.return_value = 
expected_payload + with patch("google.cloud.secretmanager.SecretManagerServiceClient", return_value=mock_client): + payload = get_gcp_secret_manager_secret(project_id, secret_id, version) + assert payload == expected_payload + + name = f"projects/{project_id}/secrets/{secret_id}/versions/{version}" + mock_client.access_secret_version.assert_called_once_with(name=name) + + +def test_get_gcp_secret_manager_secret_not_found(): + """Test retrieving non-existing secret from GCP secret manager.""" + project_id = "test_project" + secret_id = "test_secret" + version = "latest" + + mock_client = MagicMock() + mock_client.access_secret_version.side_effect = NotFound("Secret not found") + with patch("google.cloud.secretmanager.SecretManagerServiceClient", return_value=mock_client): + with pytest.raises(click.BadParameter): + get_gcp_secret_manager_secret(project_id, secret_id, version) + + +def create_mock_token(aud: str, expires_in: timedelta = None): + """Create a mock JWT token with a certain audience, expiration time, and random JTI.""" + exp = datetime.utcnow() + expires_in + jti = "test_token" + str(uuid.uuid4()) + payload = {"exp": exp, "aud": aud, "jti": jti} + + secret = "your-secret-key" + algorithm = "HS256" + + return jwt.encode(payload, secret, algorithm=algorithm) + + +@patch("flytekitplugins.identity_aware_proxy.cli.id_token.fetch_id_token") +@patch("keyring.get_password") +@patch("keyring.set_password") +def test_sa_id_token_no_token_in_keyring(kr_set_password, kr_get_password, mock_fetch_id_token): + """Test retrieval and caching of service account ID token when no token is stored in keyring yet.""" + test_audience = "test_audience" + service_account_email = "default" + + # Start with a clean KeyringStore + tmp_test_keyring_store = {} + kr_get_password.side_effect = lambda service, user: tmp_test_keyring_store.get(service, {}).get(user, None) + kr_set_password.side_effect = lambda service, user, pwd: tmp_test_keyring_store.update({service: {user: pwd}}) + + 
mock_fetch_id_token.side_effect = lambda _, aud: create_mock_token(aud, expires_in=timedelta(hours=1)) + + token = get_service_account_id_token(test_audience, service_account_email) + + assert jwt.decode(token.encode("utf-8"), options={"verify_signature": False})["aud"] == test_audience + assert jwt.decode(token.encode("utf-8"), options={"verify_signature": False})["jti"].startswith("test_token") + + # Check that the token is cached in the KeyringStore + second_token = get_service_account_id_token(test_audience, service_account_email) + + assert token == second_token + + +@patch("flytekitplugins.identity_aware_proxy.cli.id_token.fetch_id_token") +@patch("keyring.get_password") +@patch("keyring.set_password") +def test_sa_id_token_expired_token_in_keyring(kr_set_password, kr_get_password, mock_fetch_id_token): + """Test that expired service account ID token in keyring is replaced with a new one.""" + test_audience = "test_audience" + service_account_email = "default" + + # Start with an expired token in the KeyringStore + expired_id_token = create_mock_token(test_audience, expires_in=timedelta(hours=-1)) + tmp_test_keyring_store = {test_audience + "-" + service_account_email: {"id_token": expired_id_token}} + kr_get_password.side_effect = lambda service, user: tmp_test_keyring_store.get(service, {}).get(user, None) + kr_set_password.side_effect = lambda service, user, pwd: tmp_test_keyring_store.update({service: {user: pwd}}) + + mock_fetch_id_token.side_effect = lambda _, aud: create_mock_token(aud, expires_in=timedelta(hours=1)) + + token = get_service_account_id_token(test_audience, service_account_email) + + assert token != expired_id_token + assert jwt.decode(token.encode("utf-8"), options={"verify_signature": False})["aud"] == test_audience + assert jwt.decode(token.encode("utf-8"), options={"verify_signature": False})["jti"].startswith("test_token") + + +@patch("flytekitplugins.identity_aware_proxy.cli.id_token.fetch_id_token") +@patch("keyring.get_password") 
+@patch("keyring.set_password") +def test_sa_id_token_switch_accounts(kr_set_password, kr_get_password, mock_fetch_id_token): + """Test that caching works when switching service accounts.""" + test_audience = "test_audience" + service_account_email = "default" + service_account_other_email = "other" + + # Start with a clean KeyringStore + tmp_test_keyring_store = {} + kr_get_password.side_effect = lambda service, user: tmp_test_keyring_store.get(service, {}).get(user, None) + kr_set_password.side_effect = lambda service, user, pwd: tmp_test_keyring_store.update({service: {user: pwd}}) + + mock_fetch_id_token.side_effect = lambda _, aud: create_mock_token(aud, expires_in=timedelta(hours=1)) + + default_token = get_service_account_id_token(test_audience, service_account_email) + other_token = get_service_account_id_token(test_audience, service_account_other_email) + + assert default_token != other_token + + # Check that the tokens are cached in the KeyringStore + new_default_token = get_service_account_id_token(test_audience, service_account_email) + new_other_token = get_service_account_id_token(test_audience, service_account_other_email) + + assert default_token == new_default_token + assert other_token == new_other_token diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index 32709e1eaa..82ffa654dd 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -67,8 +67,15 @@ def test_command_authenticator(mock_subprocess: MagicMock): authn.refresh_credentials() -@patch("flytekit.clients.auth.token_client.requests") -def test_client_creds_authenticator(mock_requests): +@patch("flytekit.clients.auth.token_client.requests.Session") +def test_client_creds_authenticator(mock_session): + session = MagicMock() + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": 
"abc", "expires_in": 60}""") + session.post.return_value = response + mock_session.return_value = session + authn = ClientCredentialsAuthenticator( ENDPOINT, client_id="client", @@ -77,13 +84,11 @@ def test_client_creds_authenticator(mock_requests): http_proxy_url="https://my-proxy:31111", ) - response = MagicMock() - response.status_code = 200 - response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") - mock_requests.post.return_value = response authn.refresh_credentials() expected_scopes = static_cfg_store.get_client_config().scopes + assert authn._creds + assert authn._creds.access_token == "abc" assert authn._scopes == expected_scopes @@ -113,9 +118,17 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, assert authn._creds -@patch("flytekit.clients.auth.token_client.requests") -def test_client_creds_authenticator_with_custom_scopes(mock_requests): +@patch("flytekit.clients.auth.token_client.requests.Session") +def test_client_creds_authenticator_with_custom_scopes(mock_session): expected_scopes = ["foo", "baz"] + + session = MagicMock() + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + session.post.return_value = response + mock_session.return_value = session + authn = ClientCredentialsAuthenticator( ENDPOINT, client_id="client", @@ -124,11 +137,9 @@ def test_client_creds_authenticator_with_custom_scopes(mock_requests): scopes=expected_scopes, verify=True, ) - response = MagicMock() - response.status_code = 200 - response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") - mock_requests.post.return_value = response + authn.refresh_credentials() assert authn._creds + assert authn._creds.access_token == "abc" assert authn._scopes == expected_scopes diff --git a/tests/flytekit/unit/clients/auth/test_token_client.py b/tests/flytekit/unit/clients/auth/test_token_client.py index 
d0e75ec88a..a7e9c9c280 100644 --- a/tests/flytekit/unit/clients/auth/test_token_client.py +++ b/tests/flytekit/unit/clients/auth/test_token_client.py @@ -22,12 +22,14 @@ def test_get_basic_authorization_header(): assert header == "Basic Y2xpZW50X2lkOmFiYyUyNSUyNSUyNCUzRiU1QyUyRiU1QyUyRg==" -@patch("flytekit.clients.auth.token_client.requests") -def test_get_token(mock_requests): +@patch("flytekit.clients.auth.token_client.requests.Session") +def test_get_token(mock_session): + session = MagicMock() response = MagicMock() response.status_code = 200 response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") - mock_requests.post.return_value = response + session.post.return_value = response + mock_session.return_value = session access, expiration = get_token( "https://corp.idp.net", client_id="abc123", scopes=["my_scope"], http_proxy_url="http://proxy:3000", verify=True ) @@ -35,11 +37,13 @@ def test_get_token(mock_requests): assert expiration == 60 -@patch("flytekit.clients.auth.token_client.requests") -def test_get_device_code(mock_requests): +@patch("flytekit.clients.auth.token_client.requests.Session") +def test_get_device_code(mock_session): + session = MagicMock() response = MagicMock() response.ok = False - mock_requests.post.return_value = response + session.post.return_value = response + mock_session.return_value = session with pytest.raises(AuthenticationError): get_device_code("test.com", "test", http_proxy_url="http://proxy:3000") @@ -51,18 +55,21 @@ def test_get_device_code(mock_requests): "expires_in": 600, "interval": 5, } - mock_requests.post.return_value = response + session.post.return_value = response c = get_device_code("test.com", "test", http_proxy_url="http://proxy:3000") assert c assert c.device_code == "code" -@patch("flytekit.clients.auth.token_client.requests") -def test_poll_token_endpoint(mock_requests): +@patch("flytekit.clients.auth.token_client.requests.Session") +def test_poll_token_endpoint(mock_session): 
+ session = MagicMock() response = MagicMock() response.ok = False response.json.return_value = {"error": error_auth_pending} - mock_requests.post.return_value = response + + session.post.return_value = response + mock_session.return_value = session r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=1) with pytest.raises(AuthenticationError): @@ -71,8 +78,9 @@ def test_poll_token_endpoint(mock_requests): response = MagicMock() response.ok = True response.json.return_value = {"access_token": "abc", "expires_in": 60} - mock_requests.post.return_value = response + session.post.return_value = response r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=0) t, e = poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000", verify=True) - assert t - assert e + + assert t == "abc" + assert e == 60 diff --git a/tests/flytekit/unit/clients/test_auth_helper.py b/tests/flytekit/unit/clients/test_auth_helper.py index 3bd57918f4..9578f81b3e 100644 --- a/tests/flytekit/unit/clients/test_auth_helper.py +++ b/tests/flytekit/unit/clients/test_auth_helper.py @@ -1,7 +1,9 @@ import os.path +from http import HTTPStatus from unittest.mock import MagicMock, patch import pytest +import requests from flyteidl.service.auth_pb2 import OAuth2MetadataResponse, PublicClientAuthConfigResponse from flytekit.clients.auth.authenticator import ( @@ -16,8 +18,10 @@ from flytekit.clients.auth_helper import ( RemoteClientConfigStore, get_authenticator, + get_session, load_cert, upgrade_channel_to_authenticated, + upgrade_channel_to_proxy_authenticated, wrap_exceptions_channel, ) from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor @@ -76,7 +80,7 @@ def get_client_config(**kwargs) -> ClientConfigStore: authorization_endpoint=OAUTH_AUTHORIZE, redirect_uri=REDIRECT_URI, client_id=CLIENT_ID, - **kwargs + **kwargs, ) return cfg_store @@ -160,8 +164,45 @@ def 
test_upgrade_channel_to_auth(): assert isinstance(out_ch._interceptor, AuthUnaryInterceptor) # noqa +def test_upgrade_channel_to_proxy_auth(): + ch = MagicMock() + out_ch = upgrade_channel_to_proxy_authenticated( + PlatformConfig( + auth_mode="Pkce", + proxy_command=["echo", "foo-bar"], + ), + ch, + ) + assert isinstance(out_ch._interceptor, AuthUnaryInterceptor) + assert isinstance(out_ch._interceptor._authenticator, CommandAuthenticator) + + def test_load_cert(): cert_file = os.path.join(os.path.dirname(__file__), "testdata", "rootCACert.pem") f = load_cert(cert_file) assert f print(f) + + +def test_get_proxy_authenticated_session(): + """Test that proxy auth headers are added to http requests if the proxy command is provided in the platform config.""" + expected_token = "foo-bar" + platform_config = PlatformConfig( + endpoint="http://my-flyte-deployment.com", + proxy_command=["echo", expected_token], + ) + + with patch("requests.adapters.HTTPAdapter.send") as mock_send: + mock_response = requests.Response() + mock_response.status_code = HTTPStatus.UNAUTHORIZED + mock_response._content = b"{}" + mock_send.return_value = mock_response + + session = get_session(platform_config) + request = requests.Request("GET", platform_config.endpoint) + prepared_request = session.prepare_request(request) + + # Send the request to trigger the addition of the proxy auth headers + session.send(prepared_request) + + assert prepared_request.headers["proxy-authorization"] == f"Bearer {expected_token}"