Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add arg for specifying credentials #226

Merged
merged 10 commits into from
Jan 4, 2022
13 changes: 10 additions & 3 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
IPTypes,
)
from google.cloud.sql.connector.utils import generate_keys

from google.auth.credentials import Credentials
from threading import Thread
from typing import Any, Dict
from typing import Any, Dict, Optional

logger = logging.getLogger(name=__name__)

Expand All @@ -43,16 +43,21 @@ class Connector:
Enables IAM based authentication (Postgres only).

:type timeout: int
:param timeout:
:param timeout
The time limit for a connection before raising a TimeoutError.

:type credentials: google.auth.credentials.Credentials
:param credentials
Credentials object used to authenticate connections to Cloud SQL server.
If not specified, Application Default Credentials are used.
"""

def __init__(
self,
ip_types: IPTypes = IPTypes.PUBLIC,
enable_iam_auth: bool = False,
timeout: int = 30,
credentials: Optional[Credentials] = None,
) -> None:
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
self._thread: Thread = Thread(target=self._loop.run_forever, daemon=True)
Expand All @@ -66,6 +71,7 @@ def __init__(
self._timeout = timeout
self._enable_iam_auth = enable_iam_auth
self._ip_types = ip_types
self._credentials = credentials

def connect(
self, instance_connection_string: str, driver: str, **kwargs: Any
Expand Down Expand Up @@ -112,6 +118,7 @@ def connect(
driver,
self._keys,
self._loop,
self._credentials,
enable_iam_auth,
)
self._instances[instance_connection_string] = icm
Expand Down
50 changes: 40 additions & 10 deletions google/cloud/sql/connector/instance_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import datetime
from enum import Enum
import google.auth
from google.auth.credentials import Credentials
from google.auth.credentials import Credentials, with_scopes_if_required
import google.auth.transport.requests
import OpenSSL
import platform
Expand Down Expand Up @@ -117,6 +117,15 @@ def __init__(self, *args: Any) -> None:
super(PlatformNotSupportedError, self).__init__(self, *args)


class CredentialsTypeError(Exception):
"""
Raised when credentials parameter is not proper type.
"""

def __init__(self, *args: Any) -> None:
super(CredentialsTypeError, self).__init__(self, *args)


class InstanceMetadata:
ip_addrs: Dict[str, Any]
context: ssl.SSLContext
Expand Down Expand Up @@ -177,6 +186,11 @@ class InstanceConnectionManager:
The user agent string to append to SQLAdmin API requests
:type user_agent_string: str

:type credentials: google.auth.credentials.Credentials
:param credentials
Credentials object used to authenticate connections to Cloud SQL server.
If not specified, Application Default Credentials are used.

:param enable_iam_auth
Enables IAM based authentication for Postgres instances.
:type enable_iam_auth: bool
Expand Down Expand Up @@ -229,6 +243,7 @@ def __init__(
driver_name: str,
keys: concurrent.futures.Future,
loop: asyncio.AbstractEventLoop,
credentials: Optional[Credentials] = None,
enable_iam_auth: bool = False,
) -> None:
# Validate connection string
Expand All @@ -250,7 +265,14 @@ def __init__(
self._user_agent_string = f"{APPLICATION_NAME}/{version}+{driver_name}"
self._loop = loop
self._keys = asyncio.wrap_future(keys, loop=self._loop)
self._auth_init()
# validate credentials type
if not isinstance(credentials, Credentials) and credentials is not None:
raise CredentialsTypeError(
"Arg credentials must be type 'google.auth.credentials.Credentials' "
"or None (to use Application Default Credentials)"
)

self._auth_init(credentials)

self._refresh_rate_limiter = AsyncRateLimiter(
max_capacity=2, rate=1 / 30, loop=self._loop
Expand Down Expand Up @@ -343,17 +365,25 @@ async def _get_instance_data(self) -> InstanceMetadata:
self._enable_iam_auth,
)

def _auth_init(self) -> None:
def _auth_init(self, credentials: Optional[Credentials]) -> None:
"""Creates and assigns a Google Python API service object for
Google Cloud SQL Admin API.
"""

credentials, project = google.auth.default(
scopes=[
"https://www.googleapis.com/auth/sqlservice.admin",
"https://www.googleapis.com/auth/cloud-platform",
]
)
:type credentials: google.auth.credentials.Credentials
:param credentials
Credentials object used to authenticate connections to Cloud SQL server.
If not specified, Application Default Credentials are used.
"""
scopes = [
"https://www.googleapis.com/auth/sqlservice.admin",
"https://www.googleapis.com/auth/cloud-platform",
]
# if Credentials object is passed in, use for authentication
if isinstance(credentials, Credentials):
credentials = with_scopes_if_required(credentials, scopes=scopes)
# otherwise use application default credentials
else:
credentials, project = google.auth.default(scopes=scopes)

self._credentials = credentials

Expand Down
55 changes: 55 additions & 0 deletions tests/unit/test_instance_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,24 @@
"""

import asyncio
from unittest.mock import Mock, patch
import datetime
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
from typing import Any
import pytest # noqa F401 Needed to run the tests
from google.auth.credentials import Credentials
from google.cloud.sql.connector.instance_connection_manager import (
InstanceConnectionManager,
CredentialsTypeError,
)
from google.cloud.sql.connector.utils import generate_keys


@pytest.fixture
def mock_credentials() -> Credentials:
return Mock(spec=Credentials)


@pytest.fixture
def icm(
async_loop: asyncio.AbstractEventLoop, connect_string: str
Expand Down Expand Up @@ -73,6 +81,21 @@ def test_InstanceConnectionManager_init(async_loop: asyncio.AbstractEventLoop) -
)


def test_InstanceConnectionManager_init_bad_credentials(
async_loop: asyncio.AbstractEventLoop,
) -> None:
"""
Test to check whether the __init__ method of InstanceConnectionManager
throws proper error for bad credentials arg type.
"""
connect_string = "test-project:test-region:test-instance"
keys = asyncio.run_coroutine_threadsafe(generate_keys(), async_loop)
with pytest.raises(CredentialsTypeError):
assert InstanceConnectionManager(
connect_string, "pymysql", keys, async_loop, credentials=1
)


@pytest.mark.asyncio
async def test_perform_refresh_replaces_result(
icm: InstanceConnectionManager, test_rate_limiter: AsyncRateLimiter
Expand Down Expand Up @@ -171,3 +194,35 @@ async def test_force_refresh_cancels_pending_refresh(

assert pending_refresh.cancelled() is True
assert isinstance(icm._current.result(), MockMetadata)


def test_auth_init_with_credentials_object(
icm: InstanceConnectionManager, mock_credentials: Credentials
) -> None:
"""
Test that InstanceConnectionManager's _auth_init initializes _credentials
when passed a google.auth.credentials.Credentials object.
"""
setattr(icm, "_credentials", None)
with patch(
"google.cloud.sql.connector.instance_connection_manager.with_scopes_if_required"
) as mock_auth:
mock_auth.return_value = mock_credentials
icm._auth_init(credentials=mock_credentials)
assert isinstance(icm._credentials, Credentials)
mock_auth.assert_called_once()


def test_auth_init_with_default_credentials(
icm: InstanceConnectionManager, mock_credentials: Credentials
) -> None:
"""
Test that InstanceConnectionManager's _auth_init initializes _credentials
with application default credentials when credentials are not specified.
"""
setattr(icm, "_credentials", None)
with patch("google.auth.default") as mock_auth:
mock_auth.return_value = mock_credentials, None
icm._auth_init(credentials=None)
assert isinstance(icm._credentials, Credentials)
mock_auth.assert_called_once()