Skip to content

Commit

Permalink
Better types, docstring and naming
Browse files Browse the repository at this point in the history
  • Loading branch information
koszti committed Jul 9, 2021
1 parent 4706b43 commit ef8436b
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 36 deletions.
18 changes: 10 additions & 8 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,24 +1305,26 @@ def get_column_spec(
return None

@classmethod
def get_cancel_query_payload(cls, cursor: Any, query: Query) -> Any:
def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
"""
Returns None if query can not be cancelled.
Select identifiers from the database engine that uniquely identifies the
queries to cancel. The identifier is typically a session id, process id
or similar.
:param cursor: Cursor instance in which the query will be executed
:param query: Query instance
:return: Type of the payload can vary depends on databases
but must be jsonable. None if query can't be cancelled.
:return: Query identifier
"""
return None

@classmethod
def cancel_query(cls, cursor: Any, query: Query, payload: Any) -> None:
def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> None:
"""
Cancels query in the underlying database.
The method is called only when payload is not None.
Cancel query in the underlying database.
:param cursor: New cursor instance to the db of the query
:param query: Query instance
:param payload: Value returned by get_cancel_query_payload or set in
:param cancel_query_id: Value returned by get_cancel_query_payload or set in
other life-cycle methods of the query
"""

Expand Down
21 changes: 18 additions & 3 deletions superset/db_engine_specs/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,26 @@ def get_column_spec( # type: ignore
)

@classmethod
def get_cancel_query_payload(cls, cursor: Any, query: Query) -> Any:
def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
"""
Get MySQL connection ID that will be used to cancel all other running
queries in the same connection.
:param cursor: Cursor instance in which the query will be executed
:param query: Query instance
:return: MySQL Connection ID
"""
cursor.execute("SELECT CONNECTION_ID()")
row = cursor.fetchone()
return row[0]

@classmethod
def cancel_query(cls, cursor: Any, query: Query, payload: Any) -> None:
cursor.execute("KILL CONNECTION %d" % payload)
def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> None:
"""
Cancel query in the underlying database.
:param cursor: New cursor instance to the db of the query
:param query: Query instance
:param cancel_query_id: MySQL Connection ID
"""
cursor.execute("KILL CONNECTION %s" % cancel_query_id)
21 changes: 18 additions & 3 deletions superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,30 @@ def get_column_spec( # type: ignore
)

@classmethod
def get_cancel_query_payload(cls, cursor: Any, query: Query) -> Any:
def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
"""
Get Postgres PID that will be used to cancel all other running
queries in the same session.
:param cursor: Cursor instance in which the query will be executed
:param query: Query instance
:return: Postgres PID
"""
cursor.execute("SELECT pg_backend_pid()")
row = cursor.fetchone()
return row[0]

@classmethod
def cancel_query(cls, cursor: Any, query: Query, payload: Any) -> None:
def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> None:
"""
Cancel query in the underlying database.
:param cursor: New cursor instance to the db of the query
:param query: Query instance
:param cancel_query_id: Postgres PID
"""
cursor.execute(
"SELECT pg_terminate_backend(pid) "
"FROM pg_stat_activity "
"WHERE pid='%s'" % payload
"WHERE pid='%s'" % cancel_query_id
)
21 changes: 18 additions & 3 deletions superset/db_engine_specs/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,26 @@ def mutate_db_for_connection_test(database: "Database") -> None:
database.extra = json.dumps(extra)

@classmethod
def get_cancel_query_payload(cls, cursor: Any, query: Query) -> Any:
def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
"""
Get Snowflake session ID that will be used to cancel all other running
queries in the same session.
:param cursor: Cursor instance in which the query will be executed
:param query: Query instance
:return: Snowflake Session ID
"""
cursor.execute("SELECT CURRENT_SESSION()")
row = cursor.fetchone()
return row[0]

@classmethod
def cancel_query(cls, cursor: Any, query: Query, payload: Any) -> None:
cursor.execute("SELECT SYSTEM$CANCEL_ALL_QUERIES(%s)" % payload)
def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> None:
"""
Cancel query in the underlying database.
:param cursor: New cursor instance to the db of the query
:param query: Query instance
:param cancel_query_id: Snowflake Session ID
"""
cursor.execute("SELECT SYSTEM$CANCEL_ALL_QUERIES(%s)" % cancel_query_id)
14 changes: 7 additions & 7 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def dummy_sql_query_mutator(
SQL_QUERY_MUTATOR = config.get("SQL_QUERY_MUTATOR") or dummy_sql_query_mutator
log_query = config["QUERY_LOGGER"]
logger = logging.getLogger(__name__)
cancel_payload_key = "cancel_payload"
cancel_query_key = "cancel_query"


class SqlLabException(Exception):
Expand Down Expand Up @@ -449,9 +449,9 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
with closing(engine.raw_connection()) as conn:
# closing the connection closes the cursor as well
cursor = conn.cursor()
cancel_query_payload = db_engine_spec.get_cancel_query_payload(cursor, query)
if cancel_query_payload is not None:
query.set_extra_json_key(cancel_payload_key, cancel_query_payload)
cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
if cancel_query_id is not None:
query.set_extra_json_key(cancel_query_key, cancel_query_id)
session.commit()
statement_count = len(statements)
for i, statement in enumerate(statements):
Expand Down Expand Up @@ -591,8 +591,8 @@ def cancel_query(query: Query, user_name: Optional[str] = None) -> None:
:param user_name: Default username
:return: None
"""
cancel_payload = query.extra.get(cancel_payload_key, None)
if cancel_payload is None:
cancel_query_id = query.extra.get(cancel_query_key, None)
if cancel_query_id is None:
return

database = query.database
Expand All @@ -606,4 +606,4 @@ def cancel_query(query: Query, user_name: Optional[str] = None) -> None:

with closing(engine.raw_connection()) as conn:
with closing(conn.cursor()) as cursor:
db_engine_spec.cancel_query(cursor, query, cancel_payload)
db_engine_spec.cancel_query(cursor, query, cancel_query_id)
4 changes: 2 additions & 2 deletions tests/integration_tests/db_engine_specs/mysql_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,11 @@ def test_extract_errors(self):
]

@unittest.mock.patch("sqlalchemy.engine.Engine.connect")
def test_get_cancel_query_payload(self, engine_mock):
def test_get_cancel_query_id(self, engine_mock):
query = Query()
cursor_mock = engine_mock.return_value.__enter__.return_value
cursor_mock.fetchone.return_value = [123]
assert MySQLEngineSpec.get_cancel_query_payload(cursor_mock, query) == 123
assert MySQLEngineSpec.get_cancel_query_id(cursor_mock, query) == 123

@unittest.mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query(self, engine_mock):
Expand Down
11 changes: 4 additions & 7 deletions tests/integration_tests/db_engine_specs/postgres_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,20 +445,17 @@ def test_extract_errors(self):
]

@mock.patch("sqlalchemy.engine.Engine.connect")
def test_get_cancel_query_payload(self, engine_mock):
def test_get_cancel_query_id(self, engine_mock):
query = Query()
cursor_mock = engine_mock.return_value.__enter__.return_value
cursor_mock.fetchone.return_value = ["testuser"]
assert (
PostgresEngineSpec.get_cancel_query_payload(cursor_mock, query)
== "testuser"
)
cursor_mock.fetchone.return_value = [123]
assert PostgresEngineSpec.get_cancel_query_id(cursor_mock, query) == 123

@mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query(self, engine_mock):
query = Query()
cursor_mock = engine_mock.return_value.__enter__.return_value
assert PostgresEngineSpec.cancel_query(cursor_mock, query, "testuser") is None
assert PostgresEngineSpec.cancel_query(cursor_mock, query, 123) is None


def test_base_parameters_mixin():
Expand Down
6 changes: 3 additions & 3 deletions tests/integration_tests/db_engine_specs/snowflake_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@ def test_extract_errors(self):
]

@mock.patch("sqlalchemy.engine.Engine.connect")
def test_get_cancel_query_payload(self, engine_mock):
def test_get_cancel_query_id(self, engine_mock):
query = Query()
cursor_mock = engine_mock.return_value.__enter__.return_value
cursor_mock.fetchone.return_value = [123]
assert SnowflakeEngineSpec.get_cancel_query_payload(cursor_mock, query) == 123
assert SnowflakeEngineSpec.get_cancel_query_id(cursor_mock, query) == 123

@mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query(self, engine_mock):
query = Query()
cursor_mock = engine_mock.return_value.__enter__.return_value
assert SnowflakeEngineSpec.cancel_query(cursor_mock, query, None) is None
assert SnowflakeEngineSpec.cancel_query(cursor_mock, query, 123) is None

0 comments on commit ef8436b

Please sign in to comment.