diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py index 41fc5c025fb..a3d97f9214d 100644 --- a/flytekit/clients/auth_helper.py +++ b/flytekit/clients/auth_helper.py @@ -1,5 +1,6 @@ import logging import ssl +import typing import grpc from flyteidl.service.auth_pb2 import OAuth2MetadataRequest, PublicClientAuthConfigRequest @@ -24,8 +25,9 @@ class RemoteClientConfigStore(ClientConfigStore): This class implements the ClientConfigStore that is served by the Flyte Server, that implements AuthMetadataService """ - def __init__(self, secure_channel: grpc.Channel): + def __init__(self, secure_channel: grpc.Channel, scopes: typing.List[str] = None): self._secure_channel = secure_channel + self._scopes = scopes def get_client_config(self) -> ClientConfig: """ @@ -34,12 +36,14 @@ def get_client_config(self) -> ClientConfig: metadata_service = AuthMetadataServiceStub(self._secure_channel) public_client_config = metadata_service.GetPublicClientConfig(PublicClientAuthConfigRequest()) oauth2_metadata = metadata_service.GetOAuth2Metadata(OAuth2MetadataRequest()) + # Use the ones defined locally (if any) or default to `public_client_config.scopes` + scopes = self._scopes or public_client_config.scopes return ClientConfig( token_endpoint=oauth2_metadata.token_endpoint, authorization_endpoint=oauth2_metadata.authorization_endpoint, redirect_uri=public_client_config.redirect_uri, client_id=public_client_config.client_id, - scopes=public_client_config.scopes, + scopes=scopes, header_key=public_client_config.authorization_metadata_key or None, ) @@ -92,7 +96,7 @@ def upgrade_channel_to_authenticated(cfg: PlatformConfig, in_channel: grpc.Chann :param in_channel: grpc.Channel Precreated channel :return: grpc.Channel. New composite channel """ - authenticator = get_authenticator(cfg, RemoteClientConfigStore(in_channel)) + authenticator = get_authenticator(cfg, RemoteClientConfigStore(in_channel, cfg.scopes)) return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator)) diff --git a/tests/flytekit/unit/clients/test_auth_helper.py b/tests/flytekit/unit/clients/test_auth_helper.py index 8f14de730e3..4c5fd9a7191 100644 --- a/tests/flytekit/unit/clients/test_auth_helper.py +++ b/tests/flytekit/unit/clients/test_auth_helper.py @@ -64,6 +64,21 @@ def test_remote_client_config_store(mock_auth_service: MagicMock): assert ccfg is not None assert ccfg.client_id == CLIENT_ID assert ccfg.authorization_endpoint == OAUTH_AUTHORIZE + assert ccfg.scopes == ["offline", "all"] + + +@patch("flytekit.clients.auth_helper.AuthMetadataServiceStub") +def test_remote_client_config_store_with_custom_defined_scopes(mock_auth_service: MagicMock): + locally_defined_scopes = ["foo", "baz"] + ch = MagicMock() + cs = RemoteClientConfigStore(ch, locally_defined_scopes) + mock_auth_service.return_value = get_auth_service_mock() + + ccfg = cs.get_client_config() + assert ccfg is not None + assert ccfg.client_id == CLIENT_ID + assert ccfg.authorization_endpoint == OAUTH_AUTHORIZE + assert ccfg.scopes == locally_defined_scopes def get_client_config() -> ClientConfigStore: