diff --git a/superset/app.py b/superset/app.py index 5280922a1d1b0..0545002fac049 100644 --- a/superset/app.py +++ b/superset/app.py @@ -143,7 +143,7 @@ def init_views(self) -> None: from superset.databases.api import DatabaseRestApi from superset.datasets.api import DatasetRestApi from superset.queries.api import QueryRestApi - from superset.queries.savedqueries.api import SavedQueryRestApi + from superset.queries.saved_queries.api import SavedQueryRestApi from superset.views.access_requests import AccessRequestsModelView from superset.views.alerts import ( AlertLogModelView, diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 414dec48b813c..c101b772c0258 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -17,7 +17,7 @@ """A collection of ORM sqlalchemy models for SQL Lab""" import re from datetime import datetime -from typing import Any, Dict +from typing import Any, Dict, List import simplejson as json import sqlalchemy as sqla @@ -39,7 +39,7 @@ from superset import security_manager from superset.models.helpers import AuditMixinNullable, ExtraJSONMixin from superset.models.tags import QueryUpdater -from superset.sql_parse import CtasMethod +from superset.sql_parse import CtasMethod, ParsedQuery, Table from superset.utils.core import QueryStatus, user_label @@ -203,6 +203,10 @@ def sqlalchemy_uri(self) -> URL: def url(self) -> str: return "/superset/sqllab?savedQueryId={0}".format(self.id) + @property + def sql_tables(self) -> List[Table]: + return list(ParsedQuery(self.sql).tables) + class TabState(Model, AuditMixinNullable, ExtraJSONMixin): diff --git a/superset/queries/savedqueries/__init__.py b/superset/queries/saved_queries/__init__.py similarity index 100% rename from superset/queries/savedqueries/__init__.py rename to superset/queries/saved_queries/__init__.py diff --git a/superset/queries/saved_queries/api.py b/superset/queries/saved_queries/api.py new file mode 100644 index 0000000000000..81204a8b1c98e --- /dev/null +++ b/superset/queries/saved_queries/api.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Any + +from flask import g, Response +from flask_appbuilder.api import expose, protect, rison, safe +from flask_appbuilder.models.sqla.interface import SQLAInterface +from flask_babel import ngettext + +from superset.constants import RouteMethod +from superset.databases.filters import DatabaseFilter +from superset.models.sql_lab import SavedQuery +from superset.queries.saved_queries.commands.bulk_delete import ( + BulkDeleteSavedQueryCommand, +) +from superset.queries.saved_queries.commands.exceptions import ( + SavedQueryBulkDeleteFailedError, + SavedQueryNotFoundError, +) +from superset.queries.saved_queries.filters import SavedQueryFilter +from superset.queries.saved_queries.schemas import ( + get_delete_ids_schema, + openapi_spec_methods_override, +) +from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics + +logger = logging.getLogger(__name__) + + +class SavedQueryRestApi(BaseSupersetModelRestApi): + datamodel = SQLAInterface(SavedQuery) + + include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { + RouteMethod.RELATED, + RouteMethod.DISTINCT, + "bulk_delete", # not using RouteMethod since locally defined + } + class_permission_name = "SavedQueryView" + resource_name = "saved_query" + allow_browser_login = True + + base_filters = [["id", SavedQueryFilter, lambda: []]] + + show_columns = [ + "created_by.first_name", + "created_by.id", + "created_by.last_name", + "database.database_name", + "database.id", + "description", + "id", + "label", + "schema", + "sql", + "sql_tables", + ] + list_columns = [ + "created_by.first_name", + "created_by.id", + "created_by.last_name", + "database.database_name", + "database.id", + "db_id", + "description", + "label", + "schema", + "sql", + "sql_tables", + ] + add_columns = ["db_id", "description", "label", "schema", "sql"] + edit_columns = add_columns + order_columns = [ + "schema", + "label", + "description", + "sql", + "created_by.first_name", + "database.database_name", + ] + + apispec_parameter_schemas = { + "get_delete_ids_schema": get_delete_ids_schema, + } + openapi_spec_tag = "Queries" + openapi_spec_methods = openapi_spec_methods_override + + related_field_filters = { + "database": "database_name", + } + filter_rel_fields = {"database": [["id", DatabaseFilter, lambda: []]]} + allowed_rel_fields = {"database"} + allowed_distinct_fields = {"schema"} + + def pre_add(self, item: SavedQuery) -> None: + item.user = g.user + + def pre_update(self, item: SavedQuery) -> None: + self.pre_add(item) + + @expose("/", methods=["DELETE"]) + @protect() + @safe + @statsd_metrics + @rison(get_delete_ids_schema) + def bulk_delete( + self, **kwargs: Any + ) -> Response: # pylint: disable=arguments-differ + """Delete bulk Saved Queries + --- + delete: + description: >- + Deletes multiple saved queries in a bulk operation. + parameters: + - in: query + name: q + content: + application/json: + schema: + $ref: '#/components/schemas/get_delete_ids_schema' + responses: + 200: + description: Saved queries bulk delete + content: + application/json: + schema: + type: object + properties: + message: + type: string + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + item_ids = kwargs["rison"] + try: + BulkDeleteSavedQueryCommand(g.user, item_ids).run() + return self.response( + 200, + message=ngettext( + "Deleted %(num)d saved query", + "Deleted %(num)d saved queries", + num=len(item_ids), + ), + ) + except SavedQueryNotFoundError: + return self.response_404() + except SavedQueryBulkDeleteFailedError as ex: + return self.response_422(message=str(ex)) diff --git a/superset/queries/saved_queries/commands/__init__.py b/superset/queries/saved_queries/commands/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/superset/queries/saved_queries/commands/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/queries/saved_queries/commands/bulk_delete.py b/superset/queries/saved_queries/commands/bulk_delete.py new file mode 100644 index 0000000000000..cf021442fbfcc --- /dev/null +++ b/superset/queries/saved_queries/commands/bulk_delete.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import List, Optional + +from flask_appbuilder.security.sqla.models import User + +from superset.commands.base import BaseCommand +from superset.dao.exceptions import DAODeleteFailedError +from superset.models.dashboard import Dashboard +from superset.queries.saved_queries.commands.exceptions import ( + SavedQueryBulkDeleteFailedError, + SavedQueryNotFoundError, +) +from superset.queries.saved_queries.dao import SavedQueryDAO + +logger = logging.getLogger(__name__) + + +class BulkDeleteSavedQueryCommand(BaseCommand): + def __init__(self, user: User, model_ids: List[int]): + self._actor = user + self._model_ids = model_ids + self._models: Optional[List[Dashboard]] = None + + def run(self) -> None: + self.validate() + try: + SavedQueryDAO.bulk_delete(self._models) + return None + except DAODeleteFailedError as ex: + logger.exception(ex.exception) + raise SavedQueryBulkDeleteFailedError() + + def validate(self) -> None: + # Validate/populate model exists + self._models = SavedQueryDAO.find_by_ids(self._model_ids) + if not self._models or len(self._models) != len(self._model_ids): + raise SavedQueryNotFoundError() diff --git a/superset/queries/saved_queries/commands/exceptions.py b/superset/queries/saved_queries/commands/exceptions.py new file mode 100644 index 0000000000000..0e03dc7f4dced --- /dev/null +++ b/superset/queries/saved_queries/commands/exceptions.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from flask_babel import lazy_gettext as _ + +from superset.commands.exceptions import CommandException, DeleteFailedError + + +class SavedQueryBulkDeleteFailedError(DeleteFailedError): + message = _("Saved queries could not be deleted.") + + +class SavedQueryNotFoundError(CommandException): + message = _("Saved query not found.") diff --git a/superset/queries/saved_queries/dao.py b/superset/queries/saved_queries/dao.py new file mode 100644 index 0000000000000..cd20fe60de583 --- /dev/null +++ b/superset/queries/saved_queries/dao.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import List, Optional + +from sqlalchemy.exc import SQLAlchemyError + +from superset.dao.base import BaseDAO +from superset.dao.exceptions import DAODeleteFailedError +from superset.extensions import db +from superset.models.sql_lab import SavedQuery +from superset.queries.saved_queries.filters import SavedQueryFilter + +logger = logging.getLogger(__name__) + + +class SavedQueryDAO(BaseDAO): + model_cls = SavedQuery + base_filter = SavedQueryFilter + + @staticmethod + def bulk_delete(models: Optional[List[SavedQuery]], commit: bool = True) -> None: + item_ids = [model.id for model in models] if models else [] + try: + db.session.query(SavedQuery).filter(SavedQuery.id.in_(item_ids)).delete( + synchronize_session="fetch" + ) + if commit: + db.session.commit() + except SQLAlchemyError: + if commit: + db.session.rollback() + raise DAODeleteFailedError() diff --git a/superset/queries/saved_queries/filters.py b/superset/queries/saved_queries/filters.py new file mode 100644 index 0000000000000..498a061edce10 --- /dev/null +++ b/superset/queries/saved_queries/filters.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any + +from flask import g +from flask_sqlalchemy import BaseQuery + +from superset.models.sql_lab import SavedQuery +from superset.views.base import BaseFilter + + +class SavedQueryFilter(BaseFilter): # pylint: disable=too-few-public-methods + def apply(self, query: BaseQuery, value: Any) -> BaseQuery: + """ + Filter saved queries to only those created by current user. + + :returns: flask-sqlalchemy query + """ + return query.filter( + SavedQuery.created_by == g.user # pylint: disable=comparison-with-callable + ) diff --git a/superset/queries/saved_queries/schemas.py b/superset/queries/saved_queries/schemas.py new file mode 100644 index 0000000000000..d875e1a12bb6b --- /dev/null +++ b/superset/queries/saved_queries/schemas.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +openapi_spec_methods_override = { + "get": {"get": {"description": "Get a saved query",}}, + "get_list": { + "get": { + "description": "Get a list of saved queries, use Rison or JSON " + "query parameters for filtering, sorting," + " pagination and for selecting specific" + " columns and metadata.", + } + }, + "post": {"post": {"description": "Create a saved query"}}, + "put": {"put": {"description": "Update a saved query"}}, + "delete": {"delete": {"description": "Delete saved query"}}, +} + +get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}} diff --git a/superset/queries/savedqueries/api.py b/superset/queries/savedqueries/api.py deleted file mode 100644 index 0b62d37410c87..0000000000000 --- a/superset/queries/savedqueries/api.py +++ /dev/null @@ -1,103 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging - -from flask_appbuilder.models.sqla.interface import SQLAInterface - -from superset.constants import RouteMethod -from superset.databases.filters import DatabaseFilter -from superset.models.sql_lab import SavedQuery -from superset.views.base_api import BaseSupersetModelRestApi - -logger = logging.getLogger(__name__) - - -class SavedQueryRestApi(BaseSupersetModelRestApi): - datamodel = SQLAInterface(SavedQuery) - - include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { - RouteMethod.RELATED, - RouteMethod.DISTINCT, - } - class_permission_name = "SavedQueryView" - resource_name = "saved_query" - allow_browser_login = True - show_columns = [ - "id", - "schema", - "label", - "description", - "sql", - "user.first_name", - "user.last_name", - "user.id", - "database.database_name", - "database.id", - ] - list_columns = [ - "user_id", - "db_id", - "schema", - "label", - "description", - "sql", - "user.first_name", - "user.last_name", - "user.id", - "database.database_name", - "database.id", - ] - add_columns = [ - "schema", - "label", - "description", - "sql", - "user_id", - "db_id", - ] - edit_columns = add_columns - order_columns = [ - "schema", - "label", - "description", - "sql", - "user.first_name", - "database.database_name", - ] - - openapi_spec_tag = "Queries" - openapi_spec_methods = { - "get": {"get": {"description": "Get a saved query",}}, - "get_list": { - "get": { - "description": "Get a list of saved queries, use Rison or JSON " - "query parameters for filtering, sorting," - " pagination and for selecting specific" - " columns and metadata.", - } - }, - "post": {"post": {"description": "Create a saved query"}}, - "put": {"put": {"description": "Update a saved query"}}, - "delete": {"delete": {"description": "Delete saved query"}}, - } - - related_field_filters = { - "database": "database_name", - } - filter_rel_fields = {"database": [["id", DatabaseFilter, lambda: []]]} - allowed_rel_fields = {"database"} - allowed_distinct_fields = {"schema"} diff --git a/tests/queries/saved_queries/api_tests.py b/tests/queries/saved_queries/api_tests.py index 73d373089a42c..b3ce625b1fd6b 100644 --- a/tests/queries/saved_queries/api_tests.py +++ b/tests/queries/saved_queries/api_tests.py @@ -32,52 +32,63 @@ from tests.base_tests import SupersetTestCase +SAVED_QUERIES_FIXTURE_COUNT = 5 + + class TestSavedQueryApi(SupersetTestCase): def insert_saved_query( self, label: str, sql: str, db_id: Optional[int] = None, - user_id: Optional[int] = None, + created_by=None, schema: Optional[str] = "", ) -> SavedQuery: database = None - user = None if db_id: database = db.session.query(Database).get(db_id) - if user_id: - user = db.session.query(security_manager.user_model).get(user_id) query = SavedQuery( - database=database, user=user, sql=sql, label=label, schema=schema + database=database, + created_by=created_by, + sql=sql, + label=label, + schema=schema, ) db.session.add(query) db.session.commit() return query def insert_default_saved_query( - self, label: str = "saved1", schema: str = "schema1", + self, label: str = "saved1", schema: str = "schema1", username: str = "admin" ) -> SavedQuery: - admin = self.get_user("admin") + admin = self.get_user(username) example_db = get_example_database() return self.insert_saved_query( label, "SELECT col1, col2 from table1", db_id=example_db.id, - user_id=admin.id, + created_by=admin, schema=schema, ) @pytest.fixture() def create_saved_queries(self): with self.create_app().app_context(): - num_saved_queries = 5 saved_queries = [] - for cx in range(num_saved_queries): + for cx in range(SAVED_QUERIES_FIXTURE_COUNT - 1): saved_queries.append( self.insert_default_saved_query( label=f"label{cx}", schema=f"schema{cx}" ) ) + saved_queries.append( + self.insert_default_saved_query( + label=f"label{SAVED_QUERIES_FIXTURE_COUNT}", + schema=f"schema{SAVED_QUERIES_FIXTURE_COUNT}", + username="gamma", + ) + ) + yield saved_queries # rollback changes @@ -90,34 +101,55 @@ def test_get_list_saved_query(self): """ Saved Query API: Test get list saved query """ - queries = db.session.query(SavedQuery).all() + admin = self.get_user("admin") + saved_queries = ( + db.session.query(SavedQuery).filter(SavedQuery.created_by == admin).all() + ) self.login(username="admin") uri = f"api/v1/saved_query/" rv = self.get_assert_metric(uri, "get_list") assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - assert data["count"] == len(queries) + assert data["count"] == len(saved_queries) expected_columns = [ - "user_id", + "created_by", + "database", "db_id", - "schema", - "label", "description", + "label", + "schema", "sql", - "user", - "database", + "sql_tables", ] for expected_column in expected_columns: assert expected_column in data["result"][0] + @pytest.mark.usefixtures("create_saved_queries") + def test_get_list_saved_query_gamma(self): + """ + Saved Query API: Test get list saved query + """ + gamma = self.get_user("gamma") + saved_queries = ( + db.session.query(SavedQuery).filter(SavedQuery.created_by == gamma).all() + ) + + self.login(username="gamma") + uri = f"api/v1/saved_query/" + rv = self.get_assert_metric(uri, "get_list") + assert rv.status_code == 200 + data = json.loads(rv.data.decode("utf-8")) + assert data["count"] == len(saved_queries) + @pytest.mark.usefixtures("create_saved_queries") def test_get_list_sort_saved_query(self): """ Saved Query API: Test get list and sort saved query """ - all_queries = ( - db.session.query(SavedQuery).order_by(asc(SavedQuery.schema)).all() + admin = self.get_user("admin") + saved_queries = ( + db.session.query(SavedQuery).filter(SavedQuery.created_by == admin).all() ) self.login(username="admin") query_string = {"order_column": "schema", "order_direction": "asc"} @@ -125,8 +157,8 @@ def test_get_list_sort_saved_query(self): rv = self.get_assert_metric(uri, "get_list") assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - assert data["count"] == len(all_queries) - for i, query in enumerate(all_queries): + assert data["count"] == len(saved_queries) + for i, query in enumerate(saved_queries): assert query.schema == data["result"][i]["schema"] query_string = { @@ -137,7 +169,10 @@ def test_get_list_sort_saved_query(self): rv = self.get_assert_metric(uri, "get_list") assert rv.status_code == 200 - query_string = {"order_column": "user.first_name", "order_direction": "asc"} + query_string = { + "order_column": "created_by.first_name", + "order_direction": "asc", + } uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}" rv = self.get_assert_metric(uri, "get_list") assert rv.status_code == 200 @@ -202,14 +237,22 @@ def test_distinct_saved_query(self): """ SavedQuery API: Test distinct schemas """ + admin = self.get_user("admin") + saved_queries = ( + db.session.query(SavedQuery).filter(SavedQuery.created_by == admin).all() + ) + self.login(username="admin") uri = f"api/v1/saved_query/distinct/schema" rv = self.client.get(uri) assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) expected_response = { - "count": 5, - "result": [{"text": f"schema{i}", "value": f"schema{i}"} for i in range(5)], + "count": len(saved_queries), + "result": [ + {"text": f"schema{i}", "value": f"schema{i}"} + for i in range(len(saved_queries)) + ], } assert data == expected_response @@ -227,20 +270,25 @@ def test_get_saved_query(self): """ Saved Query API: Test get saved query """ - query = ( + saved_query = ( db.session.query(SavedQuery).filter(SavedQuery.label == "label1").all()[0] ) self.login(username="admin") - uri = f"api/v1/saved_query/{query.id}" + uri = f"api/v1/saved_query/{saved_query.id}" rv = self.get_assert_metric(uri, "get") assert rv.status_code == 200 expected_result = { - "id": query.id, - "database": {"id": query.database.id, "database_name": "examples"}, + "id": saved_query.id, + "database": {"id": saved_query.database.id, "database_name": "examples"}, "description": None, - "user": {"first_name": "admin", "id": query.user_id, "last_name": "user"}, + "created_by": { + "first_name": saved_query.created_by.first_name, + "id": saved_query.created_by.id, + "last_name": saved_query.created_by.last_name, + }, "sql": "SELECT col1, col2 from table1", + "sql_tables": [{"catalog": None, "schema": None, "table": "table1"}], "schema": "schema1", "label": "label1", } @@ -271,7 +319,6 @@ def test_create_saved_query(self): "label": "label1", "description": "some description", "sql": "SELECT col1, col2 from table1", - "user_id": admin.id, "db_id": example_db.id, } @@ -357,3 +404,67 @@ def test_delete_saved_query_not_found(self): uri = f"api/v1/saved_query/{max_id + 1}" rv = self.client.delete(uri) assert rv.status_code == 404 + + @pytest.mark.usefixtures("create_saved_queries") + def test_delete_bulk_saved_queries(self): + """ + Saved Query API: Test delete bulk + """ + admin = self.get_user("admin") + saved_queries = ( + db.session.query(SavedQuery).filter(SavedQuery.created_by == admin).all() + ) + saved_query_ids = [saved_query.id for saved_query in saved_queries] + + self.login(username="admin") + uri = f"api/v1/saved_query/?q={prison.dumps(saved_query_ids)}" + rv = self.delete_assert_metric(uri, "bulk_delete") + assert rv.status_code == 200 + response = json.loads(rv.data.decode("utf-8")) + expected_response = {"message": f"Deleted {len(saved_query_ids)} saved queries"} + assert response == expected_response + saved_queries = ( + db.session.query(SavedQuery).filter(SavedQuery.created_by == admin).all() + ) + assert saved_queries == [] + + @pytest.mark.usefixtures("create_saved_queries") + def test_delete_one_bulk_saved_queries(self): + """ + Saved Query API: Test delete one in bulk + """ + saved_query = db.session.query(SavedQuery).first() + saved_query_ids = [saved_query.id] + + self.login(username="admin") + uri = f"api/v1/saved_query/?q={prison.dumps(saved_query_ids)}" + rv = self.delete_assert_metric(uri, "bulk_delete") + assert rv.status_code == 200 + response = json.loads(rv.data.decode("utf-8")) + expected_response = {"message": f"Deleted {len(saved_query_ids)} saved query"} + assert response == expected_response + saved_query_ = db.session.query(SavedQuery).get(saved_query_ids[0]) + assert saved_query_ is None + + def test_delete_bulk_saved_query_bad_request(self): + """ + Saved Query API: Test delete bulk bad request + """ + saved_query_ids = [1, "a"] + self.login(username="admin") + uri = f"api/v1/saved_query/?q={prison.dumps(saved_query_ids)}" + rv = self.delete_assert_metric(uri, "bulk_delete") + assert rv.status_code == 400 + + @pytest.mark.usefixtures("create_saved_queries") + def test_delete_bulk_saved_query_not_found(self): + """ + Saved Query API: Test delete bulk not found + """ + max_id = db.session.query(func.max(SavedQuery.id)).scalar() + + saved_query_ids = [max_id + 1, max_id + 2] + self.login(username="admin") + uri = f"api/v1/saved_query/?q={prison.dumps(saved_query_ids)}" + rv = self.delete_assert_metric(uri, "bulk_delete") + assert rv.status_code == 404