Skip to content

Commit

Permalink
add extra tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Grace committed Sep 20, 2020
1 parent 40bb23c commit 0825159
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 21 deletions.
4 changes: 2 additions & 2 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,7 +1425,7 @@ def fave_slices( # pylint: disable=no-self-use
@api
@has_access_api
@event_logger.log_this
@expose("/dashboard/<dashboard_id>/stop/", methods=["POST"])
@expose("/dashboard/<int:dashboard_id>/stop/", methods=["POST"])
def stop_dashboard_queries( # pylint: disable=no-self-use
self, dashboard_id: int
) -> FlaskResponse:
Expand All @@ -1437,7 +1437,7 @@ def stop_dashboard_queries( # pylint: disable=no-self-use
for dbid in database_ids:
mydb = db.session.query(models.Database).get(dbid)
if mydb:
mydb.db_engine_spec.stop_queries(username, dashboard_id)
mydb.db_engine_spec.stop_queries(username, int(dashboard_id))

return Response(status=200)

Expand Down
25 changes: 12 additions & 13 deletions superset/views/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,21 +222,20 @@ def get_database_ids(dashboard_id: int) -> List[int]:
"""
dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one()
slices = dashboard.slices
datasource_ids = set()
database_ids = set()
datasource_ids: Set[int] = set()
database_ids: Set[int] = set()

for slc in slices:
datasource_type = slc.datasource.type
datasource_id = slc.datasource.id

if datasource_id and datasource_type:
ds_class = ConnectorRegistry.sources.get(datasource_type)
datasource = db.session.query(ds_class).filter_by(id=datasource_id).one()
if datasource and datasource_id not in datasource_ids:
datasource_ids.add(datasource_id)
database = datasource.database
if database:
database_ids.add(database.id)
datasource = slc.datasource
if (
datasource
and datasource.type == "table"
and datasource.id not in datasource_ids
):
datasource_ids.add(datasource.id)
database = datasource.database
if database:
database_ids.add(database.id)

return list(database_ids)

Expand Down
2 changes: 1 addition & 1 deletion tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,7 +1251,7 @@ def test_stop_dashboard_queries(self):

self.assertTrue(is_feature_enabled("STOP_DASHBOARD_PENDING_QUERIES"))
self.assertEqual(resp.status_code, 200)
mock_stop_queries.assert_called_once()
mock_stop_queries.assert_called_once_with(username, dashboard.id)


if __name__ == "__main__":
Expand Down
77 changes: 72 additions & 5 deletions tests/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

import tests.test_app
from superset import app, db, security_manager
from superset.connectors.base.models import BaseDatasource
from superset.exceptions import CertificateException, SupersetException
from superset.models.core import Database, Log
from superset.models.dashboard import Dashboard
Expand All @@ -46,6 +47,7 @@
get_form_data_token,
get_iterable,
get_email_address_list,
get_example_database,
get_or_create_db,
get_since_until,
get_stacktrace,
Expand Down Expand Up @@ -1142,9 +1144,74 @@ def test_get_database_ids(self) -> None:
dash_id = world_health.id
database_ids = get_database_ids(dash_id)
assert len(database_ids) == 1
assert database_ids == [get_example_database().id]

def test_get_database_ids_empty_dash(self) -> None:
# test dash with no slice
dashboard = Dashboard(dashboard_title="no slices", id=101, slices=[])
with patch("superset.db.session.query") as mock_query:
mock_query.return_value.filter_by.return_value.one.return_value = dashboard
database_ids = get_database_ids(dashboard.id)
assert database_ids == []

def test_get_database_ids_multiple_databases(self) -> None:
# test dash with 2 databases
datasource_1 = Mock()
datasource_1.type = "table"
datasource_1.datasource_name = "table_datasource_1"
datasource_1.database = Mock()

datasource_2 = Mock()
datasource_2.type = "table"
datasource_2.datasource_name = "table_datasource_2"
datasource_2.database = Mock()

slices = [
Slice(
datasource_id=datasource_1.id,
datasource_type=datasource_1.type,
datasource_name=datasource_1.datasource_name,
slice_name="slice_name_1",
),
Slice(
datasource_id=datasource_2.id,
datasource_type=datasource_2.type,
datasource_name=datasource_2.datasource_name,
slice_name="slice_name_2",
),
]
dashboard = Dashboard(dashboard_title="with 2 slices", id=102, slices=slices)
with patch("superset.db.session.query") as mock_query:
mock_query.return_value.filter_by.return_value.one.return_value = dashboard
mock_query.return_value.filter_by.return_value.first.side_effect = [
datasource_1,
datasource_2,
]
database_ids = get_database_ids(dashboard.id)
self.assertCountEqual(
database_ids, [datasource_1.database.id, datasource_2.database.id]
)

world_slice = (
db.session.query(Slice).filter_by(slice_name="World's Population").one()
)
database_id = world_slice.datasource.database.id
assert database_ids == [database_id]
def test_get_database_ids_druid(self) -> None:
druid_datasource = Mock()
druid_datasource.type = "druid"
druid_datasource.datasource_name = "druid_datasource_1"
druid_datasource.cluster = Mock()

slices = [
Slice(
datasource_id=druid_datasource.id,
datasource_type=druid_datasource.type,
datasource_name=druid_datasource.datasource_name,
slice_name="slice_name_1",
),
]
dashboard = Dashboard(dashboard_title="druid dash", id=103, slices=slices)
with patch("superset.db.session.query") as mock_query:
mock_query.return_value.filter_by.return_value.one.return_value = dashboard
mock_query.return_value.filter_by.return_value.first.return_value = (
druid_datasource
)
database_ids = get_database_ids(dashboard.id)
# druid slice has no database id
assert database_ids == []

0 comments on commit 0825159

Please sign in to comment.