diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 08c9a71f4bc0a..2a23d1c969593 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import logging from typing import Any, Dict, List, Optional, TYPE_CHECKING @@ -90,7 +92,7 @@ def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: @classmethod def get_table_names( cls, - database: "Database", + database: Database, inspector: Inspector, schema: Optional[str], ) -> List[str]: @@ -103,7 +105,7 @@ def get_table_names( @classmethod def get_view_names( cls, - database: "Database", + database: Database, inspector: Inspector, schema: Optional[str], ) -> List[str]: @@ -114,7 +116,7 @@ def get_view_names( ) @classmethod - def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]: + def get_tracking_url(cls, cursor: Cursor) -> Optional[str]: try: return cursor.info_uri except AttributeError: @@ -127,14 +129,42 @@ def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]: return None @classmethod - def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None: - """Updates progress information""" + def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: tracking_url = cls.get_tracking_url(cursor) if tracking_url: query.tracking_url = tracking_url - session.commit() + + # Adds the executed query id to the extra payload so the query can be cancelled + query.set_extra_json_key("cancel_query", cursor.stats["queryId"]) + + session.commit() BaseEngineSpec.handle_cursor(cursor=cursor, query=query, session=session) + @classmethod + def has_implicit_cancel(cls) -> bool: + return False + + @classmethod + def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: + """ + Cancel query in the underlying database. + + :param cursor: New cursor instance to the db of the query + :param query: Query instance + :param cancel_query_id: Trino `queryId` + :return: True if query cancelled successfully, False otherwise + """ + try: + cursor.execute( + f"CALL system.runtime.kill_query(query_id => '{cancel_query_id}'," + "message => 'Query cancelled by Superset')" + ) + cursor.fetchall() # needed to trigger the call + except Exception: # pylint: disable=broad-except + return False + + return True + @staticmethod def get_extra_params(database: "Database") -> Dict[str, Any]: """ diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py new file mode 100644 index 0000000000000..6a77e63236091 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument, import-outside-toplevel, protected-access +from unittest import mock + + +@mock.patch("sqlalchemy.engine.Engine.connect") +def test_cancel_query_success(engine_mock: mock.Mock) -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + from superset.models.sql_lab import Query + + query = Query() + cursor_mock = engine_mock.return_value.__enter__.return_value + assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is True + + +@mock.patch("sqlalchemy.engine.Engine.connect") +def test_cancel_query_failed(engine_mock: mock.Mock) -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + from superset.models.sql_lab import Query + + query = Query() + cursor_mock = engine_mock.raiseError.side_effect = Exception() + assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False