Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: pass custom Client object to dbapi #911

Merged
merged 5 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def connect(
credentials=None,
pool=None,
user_agent=None,
client=None,
):
"""Creates a connection to a Google Cloud Spanner database.

Expand Down Expand Up @@ -529,25 +530,31 @@ def connect(
:param user_agent: (Optional) User agent to be used with this connection's
requests.

:type client: Concrete subclass of
:class:`~google.cloud.spanner_v1.Client`.
:param user_agent: (Optional) Custom user provided Client Object
asthamohta marked this conversation as resolved.
Show resolved Hide resolved

:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
:returns: Connection object associated with the given Google Cloud Spanner
resource.
"""

client_info = ClientInfo(
user_agent=user_agent or DEFAULT_USER_AGENT,
python_version=PY_VERSION,
client_library_version=spanner.__version__,
)

if isinstance(credentials, str):
client = spanner.Client.from_service_account_json(
credentials, project=project, client_info=client_info
if client is None:
client_info = ClientInfo(
user_agent=user_agent or DEFAULT_USER_AGENT,
python_version=PY_VERSION,
client_library_version=spanner.__version__,
)
if isinstance(credentials, str):
client = spanner.Client.from_service_account_json(
credentials, project=project, client_info=client_info
)
else:
client = spanner.Client(
project=project, credentials=credentials, client_info=client_info
)
else:
client = spanner.Client(
project=project, credentials=credentials, client_info=client_info
)
if project is not None and client.project != project:
raise ValueError("project in url does not match client object project")

instance = client.instance(instance_id)
conn = Connection(instance, instance.database(database_id, pool=pool))
Expand Down
46 changes: 46 additions & 0 deletions tests/unit/spanner_dbapi/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import mock
import unittest
import warnings
import pytest

PROJECT = "test-project"
INSTANCE = "test-instance"
Expand Down Expand Up @@ -915,7 +916,52 @@ def test_request_priority(self):
sql, params, param_types=param_types, request_options=None
)

@mock.patch("google.cloud.spanner_v1.Client")
def test_custom_client_connection(self, mock_client):
from google.cloud.spanner_dbapi import connect

client = _Client()
connection = connect("test-instance", "test-database", client=client)
self.assertTrue(connection.instance._client == client)

@mock.patch("google.cloud.spanner_v1.Client")
def test_invalid_custom_client_connection(self, mock_client):
from google.cloud.spanner_dbapi import connect

client = _Client()
with pytest.raises(ValueError):
connect(
"test-instance",
"test-database",
project="invalid_project",
client=client,
)


def exit_ctx_func(self, exc_type, exc_value, traceback):
"""Context __exit__ method mock."""
pass


class _Client(object):
def __init__(self, project="project_id"):
self.project = project
self.project_name = "projects/" + self.project

def instance(self, instance_id="instance_id"):
return _Instance(name=instance_id, client=self)


class _Instance(object):
def __init__(self, name="instance_id", client=None):
self.name = name
self._client = client

def database(self, database_id="database_id", pool=None):
return _Database(database_id, pool)


class _Database(object):
def __init__(self, database_id="database_id", pool=None):
self.name = database_id
self.pool = pool