diff --git a/superset-frontend/src/SqlLab/components/SqlEditor.jsx b/superset-frontend/src/SqlLab/components/SqlEditor.jsx index 3883c43b2f916..540cfb77f3f41 100644 --- a/superset-frontend/src/SqlLab/components/SqlEditor.jsx +++ b/superset-frontend/src/SqlLab/components/SqlEditor.jsx @@ -194,6 +194,7 @@ class SqlEditor extends React.PureComponent { WINDOW_RESIZE_THROTTLE_MS, ); + this.onBeforeUnload = this.onBeforeUnload.bind(this); this.renderDropdown = this.renderDropdown.bind(this); } @@ -212,6 +213,7 @@ class SqlEditor extends React.PureComponent { this.setState({ height: this.getSqlEditorHeight() }); window.addEventListener('resize', this.handleWindowResize); + window.addEventListener('beforeunload', this.onBeforeUnload); // setup hotkeys const hotkeys = this.getHotkeyConfig(); @@ -222,6 +224,7 @@ class SqlEditor extends React.PureComponent { componentWillUnmount() { window.removeEventListener('resize', this.handleWindowResize); + window.removeEventListener('beforeunload', this.onBeforeUnload); } onResizeStart() { @@ -242,6 +245,16 @@ class SqlEditor extends React.PureComponent { } } + onBeforeUnload(event) { + if ( + this.props.database?.extra_json?.cancel_query_on_windows_unload && + this.props.latestQuery?.state === 'running' + ) { + event.preventDefault(); + this.stopQuery(); + } + } + onSqlChanged(sql) { this.setState({ sql }); this.setQueryEditorSqlWithDebounce(sql); diff --git a/superset-frontend/src/SqlLab/reducers/sqlLab.js b/superset-frontend/src/SqlLab/reducers/sqlLab.js index daa06a97c88da..9e423ba7c65b2 100644 --- a/superset-frontend/src/SqlLab/reducers/sqlLab.js +++ b/superset-frontend/src/SqlLab/reducers/sqlLab.js @@ -499,7 +499,10 @@ export default function sqlLabReducer(state = {}, action) { [actions.SET_DATABASES]() { const databases = {}; action.databases.forEach(db => { - databases[db.id] = db; + databases[db.id] = { + ...db, + extra_json: JSON.parse(db.extra || ''), + }; }); return { ...state, databases }; }, diff --git a/superset-frontend/src/views/CRUD/data/database/DatabaseModal/ExtraOptions.tsx b/superset-frontend/src/views/CRUD/data/database/DatabaseModal/ExtraOptions.tsx index a2a49b2f707b9..80b4044202202 100644 --- a/superset-frontend/src/views/CRUD/data/database/DatabaseModal/ExtraOptions.tsx +++ b/superset-frontend/src/views/CRUD/data/database/DatabaseModal/ExtraOptions.tsx @@ -294,6 +294,24 @@ const ExtraOptions = ({ /> + +
+ + +
+
Optional[str]: + """ + Select identifiers from the database engine that uniquely identifies the + queries to cancel. The identifier is typically a session id, process id + or similar. + + :param cursor: Cursor instance in which the query will be executed + :param query: Query instance + :return: Query identifier + """ + return None + + @classmethod + def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: + """ + Cancel query in the underlying database. + + :param cursor: New cursor instance to the db of the query + :param query: Query instance + :param cancel_query_id: Value returned by get_cancel_query_payload or set in + other life-cycle methods of the query + :return: True if query cancelled successfully, False otherwise + """ + # schema for adding a database by providing parameters instead of the # full SQLAlchemy URI diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 01be0c6e13d0a..a48678ef71c79 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -37,6 +37,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin from superset.errors import SupersetErrorType +from superset.models.sql_lab import Query from superset.utils import core as utils from superset.utils.core import ColumnSpec, GenericDataType @@ -220,3 +221,34 @@ def get_column_spec( # type: ignore return super().get_column_spec( native_type, column_type_mappings=column_type_mappings ) + + @classmethod + def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]: + """ + Get MySQL connection ID that will be used to cancel all other running + queries in the same connection. + + :param cursor: Cursor instance in which the query will be executed + :param query: Query instance + :return: MySQL Connection ID + """ + cursor.execute("SELECT CONNECTION_ID()") + row = cursor.fetchone() + return row[0] + + @classmethod + def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: + """ + Cancel query in the underlying database. + + :param cursor: New cursor instance to the db of the query + :param query: Query instance + :param cancel_query_id: MySQL Connection ID + :return: True if query cancelled successfully, False otherwise + """ + try: + cursor.execute(f"KILL CONNECTION {cancel_query_id}") + except Exception: # pylint: disable=broad-except + return False + + return True diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index ee95d5dd3d50d..fa8809e151014 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -40,6 +40,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin from superset.errors import SupersetErrorType from superset.exceptions import SupersetException +from superset.models.sql_lab import Query from superset.utils import core as utils from superset.utils.core import ColumnSpec, GenericDataType @@ -296,3 +297,38 @@ def get_column_spec( # type: ignore return super().get_column_spec( native_type, column_type_mappings=column_type_mappings ) + + @classmethod + def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]: + """ + Get Postgres PID that will be used to cancel all other running + queries in the same session. + + :param cursor: Cursor instance in which the query will be executed + :param query: Query instance + :return: Postgres PID + """ + cursor.execute("SELECT pg_backend_pid()") + row = cursor.fetchone() + return row[0] + + @classmethod + def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: + """ + Cancel query in the underlying database. + + :param cursor: New cursor instance to the db of the query + :param query: Query instance + :param cancel_query_id: Postgres PID + :return: True if query cancelled successfully, False otherwise + """ + try: + cursor.execute( + "SELECT pg_terminate_backend(pid) " + "FROM pg_stat_activity " + f"WHERE pid='{cancel_query_id}'" + ) + except Exception: # pylint: disable=broad-except + return False + + return True diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 11e1cd414f032..6dd85706562bd 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -25,6 +25,7 @@ from superset.db_engine_specs.postgres import PostgresBaseEngineSpec from superset.errors import SupersetErrorType +from superset.models.sql_lab import Query from superset.utils import core as utils if TYPE_CHECKING: @@ -128,3 +129,34 @@ def mutate_db_for_connection_test(database: "Database") -> None: engine_params["connect_args"] = connect_args extra["engine_params"] = engine_params database.extra = json.dumps(extra) + + @classmethod + def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]: + """ + Get Snowflake session ID that will be used to cancel all other running + queries in the same session. + + :param cursor: Cursor instance in which the query will be executed + :param query: Query instance + :return: Snowflake Session ID + """ + cursor.execute("SELECT CURRENT_SESSION()") + row = cursor.fetchone() + return row[0] + + @classmethod + def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: + """ + Cancel query in the underlying database. + + :param cursor: New cursor instance to the db of the query + :param query: Query instance + :param cancel_query_id: Snowflake Session ID + :return: True if query cancelled successfully, False otherwise + """ + try: + cursor.execute(f"SELECT SYSTEM$CANCEL_ALL_QUERIES({cancel_query_id})") + except Exception: # pylint: disable=broad-except + return False + + return True diff --git a/superset/exceptions.py b/superset/exceptions.py index 865187f64dc85..89ca41fbfcac0 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -210,3 +210,7 @@ def __init__(self, error: ValidationError): extra={"messages": error.messages}, ) super().__init__(error) + + +class SupersetCancelQueryException(SupersetException): + pass diff --git a/superset/sql_lab.py b/superset/sql_lab.py index a40e132e8e3cc..1ebf10e599dcf 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -73,6 +73,7 @@ def dummy_sql_query_mutator( SQL_QUERY_MUTATOR = config.get("SQL_QUERY_MUTATOR") or dummy_sql_query_mutator log_query = config["QUERY_LOGGER"] logger = logging.getLogger(__name__) +cancel_query_key = "cancel_query" class SqlLabException(Exception): @@ -83,6 +84,10 @@ class SqlLabSecurityException(SqlLabException): pass +class SqlLabQueryStoppedException(SqlLabException): + pass + + def handle_query_error( ex: Exception, query: Query, @@ -187,7 +192,7 @@ def get_sql_results( # pylint: disable=too-many-arguments return handle_query_error(ex, query, session) -# pylint: disable=too-many-arguments, too-many-locals +# pylint: disable=too-many-arguments, too-many-locals, too-many-statements def execute_sql_statement( sql_statement: str, query: Query, @@ -288,6 +293,12 @@ def execute_sql_statement( ) ) except Exception as ex: + # query is stopped in another thread/worker + # stopping raises expected exceptions which we should skip + session.refresh(query) + if query.status == QueryStatus.STOPPED: + raise SqlLabQueryStoppedException() + logger.error("Query %d: %s", query.id, type(ex), exc_info=True) logger.debug("Query %d: %s", query.id, ex) raise SqlLabException(db_engine_spec.extract_error_message(ex)) @@ -438,12 +449,17 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca with closing(engine.raw_connection()) as conn: # closing the connection closes the cursor as well cursor = conn.cursor() + cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query) + if cancel_query_id is not None: + query.set_extra_json_key(cancel_query_key, cancel_query_id) + session.commit() statement_count = len(statements) for i, statement in enumerate(statements): # Check if stopped - query = get_query(query_id, session) + session.refresh(query) if query.status == QueryStatus.STOPPED: - return None + payload.update({"status": query.status}) + return payload # For CTAS we create the table only on the last statement apply_ctas = query.select_as_cta and ( @@ -466,6 +482,9 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca log_params, apply_ctas, ) + except SqlLabQueryStoppedException: + payload.update({"status": QueryStatus.STOPPED}) + return payload except Exception as ex: # pylint: disable=broad-except msg = str(ex) prefix_message = ( @@ -562,3 +581,29 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca return payload return None + + +def cancel_query(query: Query, user_name: Optional[str] = None) -> bool: + """ + Cancel a running query. + + :param query: Query to cancel + :param user_name: Default username + :return: True if query cancelled successfully, False otherwise + """ + cancel_query_id = query.extra.get(cancel_query_key, None) + if cancel_query_id is None: + return False + + database = query.database + engine = database.get_sqla_engine( + schema=query.schema, + nullpool=True, + user_name=user_name, + source=QuerySource.SQL_LAB, + ) + db_engine_spec = database.db_engine_spec + + with closing(engine.raw_connection()) as conn: + with closing(conn.cursor()) as cursor: + return db_engine_spec.cancel_query(cursor, query, cancel_query_id) diff --git a/superset/views/core.py b/superset/views/core.py index b04b93c6bf702..169cd373eea13 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -80,6 +80,7 @@ CertificateException, DatabaseNotFound, SerializationError, + SupersetCancelQueryException, SupersetErrorException, SupersetErrorsException, SupersetException, @@ -2335,6 +2336,10 @@ def stop_query(self) -> FlaskResponse: str(client_id), ) return self.json_response("OK") + + if not sql_lab.cancel_query(query, g.user.username if g.user else None): + raise SupersetCancelQueryException("Could not cancel query") + query.status = QueryStatus.STOPPED db.session.commit() diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 2c9fbb9f9cc1f..d0d3d94e04130 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -177,6 +177,7 @@ def test_get_items(self): "database_name", "explore_database_id", "expose_in_sqllab", + "extra", "force_ctas_schema", "id", ] diff --git a/tests/integration_tests/db_engine_specs/mysql_tests.py b/tests/integration_tests/db_engine_specs/mysql_tests.py index 857d769681dea..b069bba69047a 100644 --- a/tests/integration_tests/db_engine_specs/mysql_tests.py +++ b/tests/integration_tests/db_engine_specs/mysql_tests.py @@ -21,6 +21,7 @@ from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.models.sql_lab import Query from superset.utils.core import GenericDataType from tests.integration_tests.db_engine_specs.base_tests import ( assert_generic_types, @@ -238,3 +239,22 @@ def test_extract_errors(self): }, ) ] + + @unittest.mock.patch("sqlalchemy.engine.Engine.connect") + def test_get_cancel_query_id(self, engine_mock): + query = Query() + cursor_mock = engine_mock.return_value.__enter__.return_value + cursor_mock.fetchone.return_value = [123] + assert MySQLEngineSpec.get_cancel_query_id(cursor_mock, query) == 123 + + @unittest.mock.patch("sqlalchemy.engine.Engine.connect") + def test_cancel_query(self, engine_mock): + query = Query() + cursor_mock = engine_mock.return_value.__enter__.return_value + assert MySQLEngineSpec.cancel_query(cursor_mock, query, 123) is True + + @unittest.mock.patch("sqlalchemy.engine.Engine.connect") + def test_cancel_query_failed(self, engine_mock): + query = Query() + cursor_mock = engine_mock.raiseError.side_effect = Exception() + assert MySQLEngineSpec.cancel_query(cursor_mock, query, 123) is False diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index 283398d4745d9..4d03c5810ab20 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -23,6 +23,7 @@ from superset.db_engine_specs import get_engine_specs from superset.db_engine_specs.postgres import PostgresEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.models.sql_lab import Query from superset.utils.core import GenericDataType from tests.integration_tests.db_engine_specs.base_tests import ( assert_generic_types, @@ -443,6 +444,25 @@ def test_extract_errors(self): ) ] + @mock.patch("sqlalchemy.engine.Engine.connect") + def test_get_cancel_query_id(self, engine_mock): + query = Query() + cursor_mock = engine_mock.return_value.__enter__.return_value + cursor_mock.fetchone.return_value = [123] + assert PostgresEngineSpec.get_cancel_query_id(cursor_mock, query) == 123 + + @mock.patch("sqlalchemy.engine.Engine.connect") + def test_cancel_query(self, engine_mock): + query = Query() + cursor_mock = engine_mock.return_value.__enter__.return_value + assert PostgresEngineSpec.cancel_query(cursor_mock, query, 123) is True + + @mock.patch("sqlalchemy.engine.Engine.connect") + def test_cancel_query_failed(self, engine_mock): + query = Query() + cursor_mock = engine_mock.raiseError.side_effect = Exception() + assert PostgresEngineSpec.cancel_query(cursor_mock, query, 123) is False + def test_base_parameters_mixin(): parameters = { diff --git a/tests/integration_tests/db_engine_specs/snowflake_tests.py b/tests/integration_tests/db_engine_specs/snowflake_tests.py index 75fe7fd92e3f1..2e74e0e68d16f 100644 --- a/tests/integration_tests/db_engine_specs/snowflake_tests.py +++ b/tests/integration_tests/db_engine_specs/snowflake_tests.py @@ -15,12 +15,14 @@ # specific language governing permissions and limitations # under the License. import json +from unittest import mock from sqlalchemy import column from superset.db_engine_specs.snowflake import SnowflakeEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.core import Database +from superset.models.sql_lab import Query from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec @@ -99,3 +101,22 @@ def test_extract_errors(self): }, ) ] + + @mock.patch("sqlalchemy.engine.Engine.connect") + def test_get_cancel_query_id(self, engine_mock): + query = Query() + cursor_mock = engine_mock.return_value.__enter__.return_value + cursor_mock.fetchone.return_value = [123] + assert SnowflakeEngineSpec.get_cancel_query_id(cursor_mock, query) == 123 + + @mock.patch("sqlalchemy.engine.Engine.connect") + def test_cancel_query(self, engine_mock): + query = Query() + cursor_mock = engine_mock.return_value.__enter__.return_value + assert SnowflakeEngineSpec.cancel_query(cursor_mock, query, 123) is True + + @mock.patch("sqlalchemy.engine.Engine.connect") + def test_cancel_query_failed(self, engine_mock): + query = Query() + cursor_mock = engine_mock.raiseError.side_effect = Exception() + assert SnowflakeEngineSpec.cancel_query(cursor_mock, query, 123) is False