Skip to content

Commit

Permalink
Pass locally defined scopes to RemoteClientConfigStore
Browse files Browse the repository at this point in the history
  • Loading branch information
franco-bocci committed Mar 17, 2023
1 parent 34f80ba commit a1c82f7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
10 changes: 7 additions & 3 deletions flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import ssl
import typing

import grpc
from flyteidl.service.auth_pb2 import OAuth2MetadataRequest, PublicClientAuthConfigRequest
Expand All @@ -24,8 +25,9 @@ class RemoteClientConfigStore(ClientConfigStore):
This class implements the ClientConfigStore that is served by the Flyte Server, that implements AuthMetadataService
"""

def __init__(self, secure_channel: grpc.Channel):
def __init__(self, secure_channel: grpc.Channel, scopes: typing.List[str] = None):
self._secure_channel = secure_channel
self._scopes = scopes

def get_client_config(self) -> ClientConfig:
"""
Expand All @@ -34,12 +36,14 @@ def get_client_config(self) -> ClientConfig:
metadata_service = AuthMetadataServiceStub(self._secure_channel)
public_client_config = metadata_service.GetPublicClientConfig(PublicClientAuthConfigRequest())
oauth2_metadata = metadata_service.GetOAuth2Metadata(OAuth2MetadataRequest())
# Use the ones defined locally (if any) or default to `public_client_config.scopes`
scopes = self._scopes or public_client_config.scopes
return ClientConfig(
token_endpoint=oauth2_metadata.token_endpoint,
authorization_endpoint=oauth2_metadata.authorization_endpoint,
redirect_uri=public_client_config.redirect_uri,
client_id=public_client_config.client_id,
scopes=public_client_config.scopes,
scopes=scopes,
header_key=public_client_config.authorization_metadata_key or None,
)

Expand Down Expand Up @@ -92,7 +96,7 @@ def upgrade_channel_to_authenticated(cfg: PlatformConfig, in_channel: grpc.Chann
:param in_channel: grpc.Channel Precreated channel
:return: grpc.Channel. New composite channel
"""
authenticator = get_authenticator(cfg, RemoteClientConfigStore(in_channel))
authenticator = get_authenticator(cfg, RemoteClientConfigStore(in_channel, cfg.scopes))
return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator))


Expand Down
15 changes: 15 additions & 0 deletions tests/flytekit/unit/clients/test_auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,21 @@ def test_remote_client_config_store(mock_auth_service: MagicMock):
assert ccfg is not None
assert ccfg.client_id == CLIENT_ID
assert ccfg.authorization_endpoint == OAUTH_AUTHORIZE
assert ccfg.scopes == ["offline", "all"]


@patch("flytekit.clients.auth_helper.AuthMetadataServiceStub")
def test_remote_client_config_store_with_custom_defined_scopes(mock_auth_service: MagicMock):
locally_defined_scopes = ["foo", "baz"]
ch = MagicMock()
cs = RemoteClientConfigStore(ch, locally_defined_scopes)
mock_auth_service.return_value = get_auth_service_mock()

ccfg = cs.get_client_config()
assert ccfg is not None
assert ccfg.client_id == CLIENT_ID
assert ccfg.authorization_endpoint == OAUTH_AUTHORIZE
assert ccfg.scopes == locally_defined_scopes


def get_client_config() -> ClientConfigStore:
Expand Down

0 comments on commit a1c82f7

Please sign in to comment.