Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix metadata reflection without a default dataset #1089

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 25 additions & 34 deletions sqlalchemy_bigquery/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@
import datetime
from decimal import Decimal
import random
import operator
import uuid

from google import auth
import google.api_core.exceptions
from google.cloud.bigquery import dbapi
from google.cloud.bigquery.table import (
RangePartitioning,
Expand Down Expand Up @@ -1054,11 +1052,6 @@ def dbapi(cls):
def import_dbapi(cls):
return dbapi

@staticmethod
def _build_formatted_table_id(table):
"""Build '<dataset_id>.<table_id>' string using given table."""
return "{}.{}".format(table.reference.dataset_id, table.table_id)

@staticmethod
def _add_default_dataset_to_job_config(job_config, project_id, dataset_id):
# If dataset_id is set, then we know the job_config isn't None
Expand Down Expand Up @@ -1107,36 +1100,34 @@ def create_connect_args(self, url):
)
return ([], {"client": client})

def _get_table_or_view_names(self, connection, item_types, schema=None):
current_schema = schema or self.dataset_id
get_table_name = (
self._build_formatted_table_id
if self.dataset_id is None
else operator.attrgetter("table_id")
)
def _get_default_schema_name(self, connection) -> str:
return connection.dialect.dataset_id

def _get_table_or_view_names(self, connection, item_types, schema=None):
client = connection.connection._client
datasets = client.list_datasets()

result = []
for dataset in datasets:
if current_schema is not None and current_schema != dataset.dataset_id:
continue

try:
tables = client.list_tables(
dataset.reference, page_size=self.list_tables_page_size
# `schema=None` means to search the default schema. If one isn't set in the
# connection string, then we have nothing to search so return an empty list.
#
# When using Alembic with `include_schemas=False`, it expects to compare to a
# single schema. If `include_schemas=True`, it will enumerate all schemas and
# then call `get_table_names`/`get_view_names` for each schema.
current_schema = schema or self.default_schema_name
if current_schema is None:
return []
try:
return [
table.table_id
for table in client.list_tables(
current_schema, page_size=self.list_tables_page_size
)
for table in tables:
if table.table_type in item_types:
result.append(get_table_name(table))
except google.api_core.exceptions.NotFound:
# It's possible that the dataset was deleted between when we
# fetched the list of datasets and when we try to list the
# tables from it. See:
# https://github.com/googleapis/python-bigquery-sqlalchemy/issues/105
pass
return result
if table.table_type in item_types
]
except NotFound:
# It's possible that the dataset was deleted between when we
# fetched the list of datasets and when we try to list the
# tables from it. See:
# https://github.com/googleapis/python-bigquery-sqlalchemy/issues/105
return []

@staticmethod
def _split_table_name(full_table_name):
Expand Down
34 changes: 13 additions & 21 deletions tests/system/test_sqlalchemy_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,18 +366,6 @@ def test_reflect_dataset_does_not_exist(engine):
)


def test_tables_list(engine, engine_using_test_dataset, bigquery_dataset):
tables = sqlalchemy.inspect(engine).get_table_names()
assert f"{bigquery_dataset}.sample" in tables
assert f"{bigquery_dataset}.sample_one_row" in tables
assert f"{bigquery_dataset}.sample_view" not in tables

tables = sqlalchemy.inspect(engine_using_test_dataset).get_table_names()
assert "sample" in tables
assert "sample_one_row" in tables
assert "sample_view" not in tables


def test_group_by(session, table, session_using_test_dataset, table_using_test_dataset):
"""labels in SELECT clause should be correctly formatted (dots are replaced with underscores)"""
for session, table in [
Expand Down Expand Up @@ -612,14 +600,15 @@ def test_schemas_names(inspector, inspector_using_test_dataset, bigquery_dataset
assert f"{bigquery_dataset}" in datasets


def test_table_names_in_schema(
inspector, inspector_using_test_dataset, bigquery_dataset
):
def test_table_names(inspector, inspector_using_test_dataset, bigquery_dataset):
tables = inspector.get_table_names()
assert not tables

tables = inspector.get_table_names(bigquery_dataset)
assert f"{bigquery_dataset}.sample" in tables
assert f"{bigquery_dataset}.sample_one_row" in tables
assert f"{bigquery_dataset}.sample_dml_empty" in tables
assert f"{bigquery_dataset}.sample_view" not in tables
assert "sample" in tables
assert "sample_one_row" in tables
assert "sample_dml_empty" in tables
assert "sample_view" not in tables
assert len(tables) == 3

tables = inspector_using_test_dataset.get_table_names()
Expand All @@ -632,8 +621,11 @@ def test_table_names_in_schema(

def test_view_names(inspector, inspector_using_test_dataset, bigquery_dataset):
view_names = inspector.get_view_names()
assert f"{bigquery_dataset}.sample_view" in view_names
assert f"{bigquery_dataset}.sample" not in view_names
assert not view_names

view_names = inspector.get_view_names(bigquery_dataset)
assert "sample_view" in view_names
assert "sample" not in view_names

view_names = inspector_using_test_dataset.get_view_names()
assert "sample_view" in view_names
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/fauxdbi.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,8 @@ def list_tables(self, dataset, page_size):
google.cloud.bigquery.table.TableListItem(
dict(
tableReference=dict(
projectId=dataset.project,
datasetId=dataset.dataset_id,
projectId="myproject",
datasetId=dataset,
tableId=row["name"],
),
type=row["type"].upper(),
Expand Down
89 changes: 38 additions & 51 deletions tests/unit/test_sqlalchemy_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,83 +65,70 @@ def table_item(dataset_id, table_id, type_="TABLE"):


@pytest.mark.parametrize(
["datasets_list", "tables_lists", "expected"],
["dataset", "tables_list", "expected"],
[
([], [], []),
([dataset_item("dataset_1")], [[]], []),
(None, [], []),
("dataset", [], []),
(
[dataset_item("dataset_1"), dataset_item("dataset_2")],
"dataset",
[
[table_item("dataset_1", "d1t1"), table_item("dataset_1", "d1t2")],
[
table_item("dataset_2", "d2t1"),
table_item("dataset_2", "d2view", type_="VIEW"),
table_item("dataset_2", "d2ext", type_="EXTERNAL"),
table_item("dataset_2", "d2mv", type_="MATERIALIZED_VIEW"),
],
table_item("dataset", "t1"),
table_item("dataset", "view", type_="VIEW"),
table_item("dataset", "ext", type_="EXTERNAL"),
table_item("dataset", "mv", type_="MATERIALIZED_VIEW"),
],
["dataset_1.d1t1", "dataset_1.d1t2", "dataset_2.d2t1", "dataset_2.d2ext"],
["t1", "ext"],
),
(
[dataset_item("dataset_1"), dataset_item("dataset_deleted")],
[
[table_item("dataset_1", "d1t1")],
google.api_core.exceptions.NotFound("dataset_deleted"),
],
["dataset_1.d1t1"],
"dataset",
google.api_core.exceptions.NotFound("dataset_deleted"),
[],
),
],
)
def test_get_table_names(
engine_under_test, mock_bigquery_client, datasets_list, tables_lists, expected
engine_under_test, mock_bigquery_client, dataset, tables_list, expected
):
mock_bigquery_client.list_datasets.return_value = datasets_list
mock_bigquery_client.list_tables.side_effect = tables_lists
table_names = sqlalchemy.inspect(engine_under_test).get_table_names()
mock_bigquery_client.list_datasets.assert_called_once()
assert mock_bigquery_client.list_tables.call_count == len(datasets_list)
mock_bigquery_client.list_tables.side_effect = [tables_list]
table_names = sqlalchemy.inspect(engine_under_test).get_table_names(schema=dataset)
if dataset:
mock_bigquery_client.list_tables.assert_called_once()
else:
mock_bigquery_client.list_tables.assert_not_called()
assert list(sorted(table_names)) == list(sorted(expected))


@pytest.mark.parametrize(
["datasets_list", "tables_lists", "expected"],
["dataset", "tables_list", "expected"],
[
([], [], []),
([dataset_item("dataset_1")], [[]], []),
(None, [], []),
("dataset", [], []),
(
[dataset_item("dataset_1"), dataset_item("dataset_2")],
"dataset",
[
[
table_item("dataset_1", "d1t1"),
table_item("dataset_1", "d1view", type_="VIEW"),
],
[
table_item("dataset_2", "d2t1"),
table_item("dataset_2", "d2view", type_="VIEW"),
table_item("dataset_2", "d2ext", type_="EXTERNAL"),
table_item("dataset_2", "d2mv", type_="MATERIALIZED_VIEW"),
],
table_item("dataset", "t1"),
table_item("dataset", "view", type_="VIEW"),
table_item("dataset", "ext", type_="EXTERNAL"),
table_item("dataset", "mv", type_="MATERIALIZED_VIEW"),
],
["dataset_1.d1view", "dataset_2.d2view", "dataset_2.d2mv"],
["view", "mv"],
),
(
[dataset_item("dataset_1"), dataset_item("dataset_deleted")],
[
[table_item("dataset_1", "d1view", type_="VIEW")],
google.api_core.exceptions.NotFound("dataset_deleted"),
],
["dataset_1.d1view"],
"dataset_deleted",
google.api_core.exceptions.NotFound("dataset_deleted"),
[],
),
],
)
def test_get_view_names(
inspector_under_test, mock_bigquery_client, datasets_list, tables_lists, expected
inspector_under_test, mock_bigquery_client, dataset, tables_list, expected
):
mock_bigquery_client.list_datasets.return_value = datasets_list
mock_bigquery_client.list_tables.side_effect = tables_lists
view_names = inspector_under_test.get_view_names()
mock_bigquery_client.list_datasets.assert_called_once()
assert mock_bigquery_client.list_tables.call_count == len(datasets_list)
mock_bigquery_client.list_tables.side_effect = [tables_list]
view_names = inspector_under_test.get_view_names(schema=dataset)
if dataset:
mock_bigquery_client.list_tables.assert_called_once()
else:
mock_bigquery_client.list_tables.assert_not_called()
assert list(sorted(view_names)) == list(sorted(expected))


Expand Down