diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index f23019f607..82ffa654dd 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -118,9 +118,17 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, assert authn._creds -@patch("flytekit.clients.auth.token_client.requests") -def test_client_creds_authenticator_with_custom_scopes(mock_requests): +@patch("flytekit.clients.auth.token_client.requests.Session") +def test_client_creds_authenticator_with_custom_scopes(mock_session): expected_scopes = ["foo", "baz"] + + session = MagicMock() + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + session.post.return_value = response + mock_session.return_value = session + authn = ClientCredentialsAuthenticator( ENDPOINT, client_id="client", @@ -129,11 +137,9 @@ def test_client_creds_authenticator_with_custom_scopes(mock_requests): scopes=expected_scopes, verify=True, ) - response = MagicMock() - response.status_code = 200 - response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") - mock_requests.post.return_value = response + authn.refresh_credentials() assert authn._creds + assert authn._creds.access_token == "abc" assert authn._scopes == expected_scopes