From ef8436b783a2baac74322ff2b98c0189edc332e6 Mon Sep 17 00:00:00 2001 From: Peter Kosztolanyi Date: Fri, 9 Jul 2021 04:23:53 +0100 Subject: [PATCH] Better types, docstring and naming --- superset/db_engine_specs/base.py | 18 +++++++++------- superset/db_engine_specs/mysql.py | 21 ++++++++++++++++--- superset/db_engine_specs/postgres.py | 21 ++++++++++++++++--- superset/db_engine_specs/snowflake.py | 21 ++++++++++++++++--- superset/sql_lab.py | 14 ++++++------- .../db_engine_specs/mysql_tests.py | 4 ++-- .../db_engine_specs/postgres_tests.py | 11 ++++------ .../db_engine_specs/snowflake_tests.py | 6 +++--- 8 files changed, 80 insertions(+), 36 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 8fef471b38835..0634c8470c492 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1305,24 +1305,26 @@ def get_column_spec( return None @classmethod - def get_cancel_query_payload(cls, cursor: Any, query: Query) -> Any: + def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]: """ - Returns None if query can not be cancelled. + Select identifiers from the database engine that uniquely identifies the + queries to cancel. The identifier is typically a session id, process id + or similar. + :param cursor: Cursor instance in which the query will be executed :param query: Query instance - :return: Type of the payload can vary depends on databases - but must be jsonable. None if query can't be cancelled. + :return: Query identifier """ return None @classmethod - def cancel_query(cls, cursor: Any, query: Query, payload: Any) -> None: + def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> None: """ - Cancels query in the underlying database. - The method is called only when payload is not None. + Cancel query in the underlying database. + :param cursor: New cursor instance to the db of the query :param query: Query instance - :param payload: Value returned by get_cancel_query_payload or set in + :param cancel_query_id: Value returned by get_cancel_query_payload or set in other life-cycle methods of the query """ diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 01f41fe8b38c4..896edfa6f7150 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -223,11 +223,26 @@ def get_column_spec( # type: ignore ) @classmethod - def get_cancel_query_payload(cls, cursor: Any, query: Query) -> Any: + def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]: + """ + Get MySQL connection ID that will be used to cancel all other running + queries in the same connection. + + :param cursor: Cursor instance in which the query will be executed + :param query: Query instance + :return: MySQL Connection ID + """ cursor.execute("SELECT CONNECTION_ID()") row = cursor.fetchone() return row[0] @classmethod - def cancel_query(cls, cursor: Any, query: Query, payload: Any) -> None: - cursor.execute("KILL CONNECTION %d" % payload) + def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> None: + """ + Cancel query in the underlying database. + + :param cursor: New cursor instance to the db of the query + :param query: Query instance + :param cancel_query_id: MySQL Connection ID + """ + cursor.execute("KILL CONNECTION %s" % cancel_query_id) diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 1871fd6bc38c7..903d898bb938b 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -299,15 +299,30 @@ def get_column_spec( # type: ignore ) @classmethod - def get_cancel_query_payload(cls, cursor: Any, query: Query) -> Any: + def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]: + """ + Get Postgres PID that will be used to cancel all other running + queries in the same session. + + :param cursor: Cursor instance in which the query will be executed + :param query: Query instance + :return: Postgres PID + """ cursor.execute("SELECT pg_backend_pid()") row = cursor.fetchone() return row[0] @classmethod - def cancel_query(cls, cursor: Any, query: Query, payload: Any) -> None: + def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> None: + """ + Cancel query in the underlying database. + + :param cursor: New cursor instance to the db of the query + :param query: Query instance + :param cancel_query_id: Postgres PID + """ cursor.execute( "SELECT pg_terminate_backend(pid) " "FROM pg_stat_activity " - "WHERE pid='%s'" % payload + "WHERE pid='%s'" % cancel_query_id ) diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 2547171916ecd..a12da85939f09 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -131,11 +131,26 @@ def mutate_db_for_connection_test(database: "Database") -> None: database.extra = json.dumps(extra) @classmethod - def get_cancel_query_payload(cls, cursor: Any, query: Query) -> Any: + def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]: + """ + Get Snowflake session ID that will be used to cancel all other running + queries in the same session. + + :param cursor: Cursor instance in which the query will be executed + :param query: Query instance + :return: Snowflake Session ID + """ cursor.execute("SELECT CURRENT_SESSION()") row = cursor.fetchone() return row[0] @classmethod - def cancel_query(cls, cursor: Any, query: Query, payload: Any) -> None: - cursor.execute("SELECT SYSTEM$CANCEL_ALL_QUERIES(%s)" % payload) + def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> None: + """ + Cancel query in the underlying database. + + :param cursor: New cursor instance to the db of the query + :param query: Query instance + :param cancel_query_id: Snowflake Session ID + """ + cursor.execute("SELECT SYSTEM$CANCEL_ALL_QUERIES(%s)" % cancel_query_id) diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 339f666fa14b3..32c8fe706c00e 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -73,7 +73,7 @@ def dummy_sql_query_mutator( SQL_QUERY_MUTATOR = config.get("SQL_QUERY_MUTATOR") or dummy_sql_query_mutator log_query = config["QUERY_LOGGER"] logger = logging.getLogger(__name__) -cancel_payload_key = "cancel_payload" +cancel_query_key = "cancel_query" class SqlLabException(Exception): @@ -449,9 +449,9 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca with closing(engine.raw_connection()) as conn: # closing the connection closes the cursor as well cursor = conn.cursor() - cancel_query_payload = db_engine_spec.get_cancel_query_payload(cursor, query) - if cancel_query_payload is not None: - query.set_extra_json_key(cancel_payload_key, cancel_query_payload) + cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query) + if cancel_query_id is not None: + query.set_extra_json_key(cancel_query_key, cancel_query_id) session.commit() statement_count = len(statements) for i, statement in enumerate(statements): @@ -591,8 +591,8 @@ def cancel_query(query: Query, user_name: Optional[str] = None) -> None: :param user_name: Default username :return: None """ - cancel_payload = query.extra.get(cancel_payload_key, None) - if cancel_payload is None: + cancel_query_id = query.extra.get(cancel_query_key, None) + if cancel_query_id is None: return database = query.database @@ -606,4 +606,4 @@ def cancel_query(query: Query, user_name: Optional[str] = None) -> None: with closing(engine.raw_connection()) as conn: with closing(conn.cursor()) as cursor: - db_engine_spec.cancel_query(cursor, query, cancel_payload) + db_engine_spec.cancel_query(cursor, query, cancel_query_id) diff --git a/tests/integration_tests/db_engine_specs/mysql_tests.py b/tests/integration_tests/db_engine_specs/mysql_tests.py index 6fc29078a7fe7..b5c2e53b996e7 100644 --- a/tests/integration_tests/db_engine_specs/mysql_tests.py +++ b/tests/integration_tests/db_engine_specs/mysql_tests.py @@ -241,11 +241,11 @@ def test_extract_errors(self): ] @unittest.mock.patch("sqlalchemy.engine.Engine.connect") - def test_get_cancel_query_payload(self, engine_mock): + def test_get_cancel_query_id(self, engine_mock): query = Query() cursor_mock = engine_mock.return_value.__enter__.return_value cursor_mock.fetchone.return_value = [123] - assert MySQLEngineSpec.get_cancel_query_payload(cursor_mock, query) == 123 + assert MySQLEngineSpec.get_cancel_query_id(cursor_mock, query) == 123 @unittest.mock.patch("sqlalchemy.engine.Engine.connect") def test_cancel_query(self, engine_mock): diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index da634121cc06c..874ab80d9792a 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -445,20 +445,17 @@ def test_extract_errors(self): ] @mock.patch("sqlalchemy.engine.Engine.connect") - def test_get_cancel_query_payload(self, engine_mock): + def test_get_cancel_query_id(self, engine_mock): query = Query() cursor_mock = engine_mock.return_value.__enter__.return_value - cursor_mock.fetchone.return_value = ["testuser"] - assert ( - PostgresEngineSpec.get_cancel_query_payload(cursor_mock, query) - == "testuser" - ) + cursor_mock.fetchone.return_value = [123] + assert PostgresEngineSpec.get_cancel_query_id(cursor_mock, query) == 123 @mock.patch("sqlalchemy.engine.Engine.connect") def test_cancel_query(self, engine_mock): query = Query() cursor_mock = engine_mock.return_value.__enter__.return_value - assert PostgresEngineSpec.cancel_query(cursor_mock, query, "testuser") is None + assert PostgresEngineSpec.cancel_query(cursor_mock, query, 123) is None def test_base_parameters_mixin(): diff --git a/tests/integration_tests/db_engine_specs/snowflake_tests.py b/tests/integration_tests/db_engine_specs/snowflake_tests.py index c4792c1279a1f..36cdb1e4c5596 100644 --- a/tests/integration_tests/db_engine_specs/snowflake_tests.py +++ b/tests/integration_tests/db_engine_specs/snowflake_tests.py @@ -103,14 +103,14 @@ def test_extract_errors(self): ] @mock.patch("sqlalchemy.engine.Engine.connect") - def test_get_cancel_query_payload(self, engine_mock): + def test_get_cancel_query_id(self, engine_mock): query = Query() cursor_mock = engine_mock.return_value.__enter__.return_value cursor_mock.fetchone.return_value = [123] - assert SnowflakeEngineSpec.get_cancel_query_payload(cursor_mock, query) == 123 + assert SnowflakeEngineSpec.get_cancel_query_id(cursor_mock, query) == 123 @mock.patch("sqlalchemy.engine.Engine.connect") def test_cancel_query(self, engine_mock): query = Query() cursor_mock = engine_mock.return_value.__enter__.return_value - assert SnowflakeEngineSpec.cancel_query(cursor_mock, query, None) is None + assert SnowflakeEngineSpec.cancel_query(cursor_mock, query, 123) is None