diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index b1cf19647d..b2b82831c7 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -25,6 +25,7 @@ class ClientConfig: device_authorization_endpoint: typing.Optional[str] = None scopes: typing.List[str] = None header_key: str = "authorization" + audience: typing.Optional[str] = None class ClientConfigStore(object): @@ -174,6 +175,7 @@ def __init__( scopes: typing.Optional[typing.List[str]] = None, http_proxy_url: typing.Optional[str] = None, verify: typing.Optional[typing.Union[bool, str]] = None, + audience: typing.Optional[str] = None, ): if not client_id or not client_secret: raise ValueError("Client ID and Client SECRET both are required.") @@ -183,6 +185,7 @@ def __init__( self._scopes = scopes or cfg.scopes self._client_id = client_id self._client_secret = client_secret + self._audience = audience or cfg.audience super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url, verify=verify) def refresh_credentials(self): @@ -195,14 +198,21 @@ def refresh_credentials(self): """ token_endpoint = self._token_endpoint scopes = self._scopes + audience = self._audience # Note that unlike the Pkce flow, the client ID does not come from Admin. logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}") authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret) token, expires_in = token_client.get_token( - token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url, verify=self._verify + token_endpoint=token_endpoint, + authorization_header=authorization_header, + http_proxy_url=self._http_proxy_url, + verify=self._verify, + scopes=scopes, + audience=audience, ) + logging.info("Retrieved new token, expires in {}".format(expires_in)) self._creds = Credentials(token) diff --git a/flytekit/clients/auth/token_client.py b/flytekit/clients/auth/token_client.py index 2e14fe8afc..e5eae32ed7 100644 --- a/flytekit/clients/auth/token_client.py +++ b/flytekit/clients/auth/token_client.py @@ -74,6 +74,7 @@ def get_token( authorization_header: typing.Optional[str] = None, client_id: typing.Optional[str] = None, device_code: typing.Optional[str] = None, + audience: typing.Optional[str] = None, grant_type: GrantType = GrantType.CLIENT_CREDS, http_proxy_url: typing.Optional[str] = None, verify: typing.Optional[typing.Union[bool, str]] = None, @@ -98,9 +99,12 @@ def get_token( body["device_code"] = device_code if scopes is not None: body["scope"] = ",".join(scopes) + if audience: + body["audience"] = audience proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify) + if not response.ok: j = response.json() if "error" in j: diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py index ce2992723f..5c4fafe579 100644 --- a/flytekit/clients/auth_helper.py +++ b/flytekit/clients/auth_helper.py @@ -43,6 +43,7 @@ def get_client_config(self) -> ClientConfig: scopes=public_client_config.scopes, header_key=public_client_config.authorization_metadata_key or None, device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint, + audience=public_client_config.audience, ) @@ -73,6 +74,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth client_secret=cfg.client_credentials_secret, cfg_store=cfg_store, scopes=cfg.scopes, + audience=cfg.audience, http_proxy_url=cfg.http_proxy_url, verify=verify, ) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index a7e2c69ebd..e31af5f389 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -428,6 +428,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file)) kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file)) kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file)) + kwargs = set_if_exists(kwargs, "http_proxy_url", _internal.Platform.HTTP_PROXY_URL.read(config_file)) return PlatformConfig(**kwargs)