Skip to content

Commit

Permalink
feat: Add Auth0/audience support for ClientCredentials flow (#1639)
Browse files Browse the repository at this point in the history
* feat: Add Auth0/audience support for ClientCredentials flow

Signed-off-by: tnam <[email protected]>

* refactor: Remove unneeded variables & condense code

Signed-off-by: tnam <[email protected]>

* refactor: Reduce verbosity of code

Signed-off-by: tnam <[email protected]>

* refactor(chore): Remove unused commented code

Signed-off-by: tnam <[email protected]>

* fix: Missing comma in input args - authenticator.py 213

Signed-off-by: tnam <[email protected]>

* style: Run pre-commit against all files

Signed-off-by: tnam <[email protected]>

---------

Signed-off-by: tnam <[email protected]>
Co-authored-by: tnam <[email protected]>
  • Loading branch information
PudgyPigeon and tnam authored Jun 9, 2023
1 parent 8437023 commit 3370a96
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 1 deletion.
12 changes: 11 additions & 1 deletion flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class ClientConfig:
device_authorization_endpoint: typing.Optional[str] = None
scopes: typing.List[str] = None
header_key: str = "authorization"
audience: typing.Optional[str] = None


class ClientConfigStore(object):
Expand Down Expand Up @@ -174,6 +175,7 @@ def __init__(
scopes: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
audience: typing.Optional[str] = None,
):
if not client_id or not client_secret:
raise ValueError("Client ID and Client SECRET both are required.")
Expand All @@ -183,6 +185,7 @@ def __init__(
self._scopes = scopes or cfg.scopes
self._client_id = client_id
self._client_secret = client_secret
self._audience = audience or cfg.audience
super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url, verify=verify)

def refresh_credentials(self):
Expand All @@ -195,14 +198,21 @@ def refresh_credentials(self):
"""
token_endpoint = self._token_endpoint
scopes = self._scopes
audience = self._audience

# Note that unlike the Pkce flow, the client ID does not come from Admin.
logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}")
authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret)

token, expires_in = token_client.get_token(
token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url, verify=self._verify
token_endpoint=token_endpoint,
authorization_header=authorization_header,
http_proxy_url=self._http_proxy_url,
verify=self._verify,
scopes=scopes,
audience=audience,
)

logging.info("Retrieved new token, expires in {}".format(expires_in))
self._creds = Credentials(token)

Expand Down
4 changes: 4 additions & 0 deletions flytekit/clients/auth/token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def get_token(
authorization_header: typing.Optional[str] = None,
client_id: typing.Optional[str] = None,
device_code: typing.Optional[str] = None,
audience: typing.Optional[str] = None,
grant_type: GrantType = GrantType.CLIENT_CREDS,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
Expand All @@ -98,9 +99,12 @@ def get_token(
body["device_code"] = device_code
if scopes is not None:
body["scope"] = ",".join(scopes)
if audience:
body["audience"] = audience

proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify)

if not response.ok:
j = response.json()
if "error" in j:
Expand Down
2 changes: 2 additions & 0 deletions flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def get_client_config(self) -> ClientConfig:
scopes=public_client_config.scopes,
header_key=public_client_config.authorization_metadata_key or None,
device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint,
audience=public_client_config.audience,
)


Expand Down Expand Up @@ -73,6 +74,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
client_secret=cfg.client_credentials_secret,
cfg_store=cfg_store,
scopes=cfg.scopes,
audience=cfg.audience,
http_proxy_url=cfg.http_proxy_url,
verify=verify,
)
Expand Down
1 change: 1 addition & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None
kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file))
kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file))
kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file))

kwargs = set_if_exists(kwargs, "http_proxy_url", _internal.Platform.HTTP_PROXY_URL.read(config_file))
return PlatformConfig(**kwargs)

Expand Down

0 comments on commit 3370a96

Please sign in to comment.