From b22e5e3415c84617b49e0faf94de60abbeb619ab Mon Sep 17 00:00:00 2001 From: Franco Bocci <121866694+franco-bocci@users.noreply.github.com> Date: Fri, 24 Mar 2023 23:20:25 +0100 Subject: [PATCH] Pass locally defined scopes to RemoteClientConfigStore (#1553) Signed-off-by: franco-bocci --- flytekit/clients/auth/authenticator.py | 6 +++-- flytekit/clients/auth_helper.py | 1 + .../unit/clients/auth/test_authenticator.py | 25 ++++++++++++++++++- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index 183c1787cd..1fe0d9711c 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -163,13 +163,15 @@ def __init__( client_id: str, client_secret: str, cfg_store: ClientConfigStore, - header_key: str = None, + header_key: typing.Optional[str] = None, + scopes: typing.Optional[typing.List[str]] = None, ): if not client_id or not client_secret: raise ValueError("Client ID and Client SECRET both are required.") cfg = cfg_store.get_client_config() self._token_endpoint = cfg.token_endpoint - self._scopes = cfg.scopes + # Use scopes from `flytekit.configuration.PlatformConfig` if passed + self._scopes = scopes or cfg.scopes self._client_id = client_id self._client_secret = client_secret super().__init__(endpoint, cfg.header_key or header_key) diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py index 41fc5c025f..3a5464fd6e 100644 --- a/flytekit/clients/auth_helper.py +++ b/flytekit/clients/auth_helper.py @@ -69,6 +69,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth client_id=cfg.client_id, client_secret=cfg.client_credentials_secret, cfg_store=cfg_store, + scopes=cfg.scopes, ) elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND: client_cfg = None diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index 4c968cf0bd..5c1586970a 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -82,7 +82,7 @@ def test_get_token(mock_requests): @patch("flytekit.clients.auth.authenticator.requests") -def test_client_creds_authenticator(mock_requests): +def test_client_creds_authenticator_without_custom_scopes(mock_requests): authn = ClientCredentialsAuthenticator( ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store ) @@ -92,4 +92,27 @@ def test_client_creds_authenticator(mock_requests): response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") mock_requests.post.return_value = response authn.refresh_credentials() + expected_scopes = static_cfg_store.get_client_config().scopes + + assert authn._creds + assert authn._scopes == expected_scopes + + +@patch("flytekit.clients.auth.authenticator.requests") +def test_client_creds_authenticator_with_custom_scopes(mock_requests): + expected_scopes = ["foo", "baz"] + authn = ClientCredentialsAuthenticator( + ENDPOINT, + client_id="client", + client_secret="secret", + cfg_store=static_cfg_store, + scopes=expected_scopes, + ) + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + mock_requests.post.return_value = response + authn.refresh_credentials() + assert authn._creds + assert authn._scopes == expected_scopes