Skip to content

Commit

Permalink
fix(dataset): create ES-View dataset raise exception apache#16623 (ap…
Browse files Browse the repository at this point in the history
…ache#16624)

* fix(dataset): create es-view dataset raise exception apache#16623

* fix(database): fix has_view logic

* refactor(database): fix logic

* style(lint): remove unused typing

* fix(test): add test case

* fix(test): fix test case
  • Loading branch information
aniaan authored and Emmanuel Bavoux committed Nov 14, 2021
1 parent d448237 commit 709871a
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 2 deletions.
6 changes: 5 additions & 1 deletion superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ def get_physical_table_metadata(
# ensure empty schema
_schema_name = schema_name if schema_name else None
# Table does not exist or is not visible to a connection.
if not database.has_table_by_name(table_name, schema=_schema_name):

if not (
database.has_table_by_name(table_name=table_name, schema=_schema_name)
or database.has_view_by_name(view_name=table_name, schema=_schema_name)
):
raise NoSuchTableError

cols = database.get_columns(table_name, schema=_schema_name)
Expand Down
24 changes: 23 additions & 1 deletion superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
Table,
Text,
)
from sqlalchemy.engine import Dialect, Engine, url
from sqlalchemy.engine import Connection, Dialect, Engine, url
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.exc import ArgumentError
Expand Down Expand Up @@ -721,6 +721,28 @@ def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bo
engine = self.get_sqla_engine()
return engine.has_table(table_name, schema)

@classmethod
def _has_view(
cls,
conn: Connection,
dialect: Dialect,
view_name: str,
schema: Optional[str] = None,
) -> bool:
view_names: List[str] = []
try:
view_names = dialect.get_view_names(connection=conn, schema=schema)
except Exception as ex: # pylint: disable=broad-except
logger.warning(ex)
return view_name in view_names

def has_view(self, view_name: str, schema: Optional[str] = None) -> bool:
engine = self.get_sqla_engine()
return engine.run_callable(self._has_view, engine.dialect, view_name, schema)

def has_view_by_name(self, view_name: str, schema: Optional[str] = None) -> bool:
return self.has_view(view_name=view_name, schema=schema)

@memoized
def get_dialect(self) -> Dialect:
sqla_url = url.make_url(self.sqlalchemy_uri_decrypted)
Expand Down
43 changes: 43 additions & 0 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,49 @@ def test_create_dataset_validate_tables_exists(self):
rv = self.post_assert_metric(uri, table_data, "post")
assert rv.status_code == 422

@patch("superset.models.core.Database.get_columns")
@patch("superset.models.core.Database.has_table_by_name")
@patch("superset.models.core.Database.get_table")
def test_create_dataset_validate_view_exists(
self, mock_get_table, mock_has_table_by_name, mock_get_columns
):
"""
Dataset API: Test create dataset validate view exists
"""

mock_get_columns.return_value = [
{"name": "col", "type": "VARCHAR", "type_generic": None, "is_dttm": None,}
]

mock_has_table_by_name.return_value = False
mock_get_table.return_value = None

example_db = get_example_database()
engine = example_db.get_sqla_engine()
dialect = engine.dialect

with patch.object(
dialect, "get_view_names", wraps=dialect.get_view_names
) as patch_get_view_names:
patch_get_view_names.return_value = ["test_case_view"]

self.login(username="admin")
table_data = {
"database": example_db.id,
"schema": "",
"table_name": "test_case_view",
}

uri = "api/v1/dataset/"
rv = self.post_assert_metric(uri, table_data, "post")
assert rv.status_code == 201

# cleanup
data = json.loads(rv.data.decode("utf-8"))
uri = f'api/v1/dataset/{data.get("id")}'
rv = self.client.delete(uri)
assert rv.status_code == 200

@patch("superset.datasets.dao.DatasetDAO.create")
def test_create_dataset_sqlalchemy_error(self, mock_dao_create):
"""
Expand Down

0 comments on commit 709871a

Please sign in to comment.