Skip to content

Commit

Permalink
feat(FIR-43324): async cancellation method (#420)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiurin authored Feb 11, 2025
1 parent 4d7008e commit 3d92df9
Show file tree
Hide file tree
Showing 9 changed files with 342 additions and 16 deletions.
17 changes: 17 additions & 0 deletions docsrc/Connecting_and_queries.rst
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,23 @@ has finished successfully, None if query is still running and False if the query
else:
print("Query failed")

Cancelling a running query
--------------------------

To cancel a running query, use the :py:meth:`firebolt.db.connection.Connection.cancel_async_query` method. This method
will send a cancel request to the server and the query will be stopped.

::

token = cursor.async_query_token
connection.cancel_async_query(token)
# Verify that the query was cancelled
running = connection.is_async_query_running(token)
print(running) # False
successful = connection.is_async_query_successful(token)
print(successful) # False


Thread safety
==============================
Expand Down
44 changes: 36 additions & 8 deletions src/firebolt/async_db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
from firebolt.client.auth import Auth
from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2
from firebolt.common.base_connection import (
ASYNC_QUERY_CANCEL,
ASYNC_QUERY_STATUS_REQUEST,
ASYNC_QUERY_STATUS_RUNNING,
ASYNC_QUERY_STATUS_SUCCESSFUL,
AsyncQueryInfo,
BaseConnection,
)
from firebolt.common.cache import _firebolt_system_engine_cache
Expand Down Expand Up @@ -90,19 +92,33 @@ def cursor(self, **kwargs: Any) -> Cursor:
return c

# Server-side async methods
async def _get_async_query_status(self, token: str) -> str:
async def _get_async_query_info(self, token: str) -> AsyncQueryInfo:
if self.cursor_type != CursorV2:
raise FireboltError(
"This method is only supported for connection with service account."
)
cursor = self.cursor()
await cursor.execute(ASYNC_QUERY_STATUS_REQUEST.format(token=token))
await cursor.execute(ASYNC_QUERY_STATUS_REQUEST, [token])
result = await cursor.fetchone()
if cursor.rowcount != 1 or not result:
raise FireboltError("Unexpected result from async query status request.")
columns = cursor.description
result_dict = dict(zip([column.name for column in columns], result))
return str(result_dict.get("status"))

if not result_dict.get("status") or not result_dict.get("query_id"):
raise FireboltError(
"Something went wrong - async query status request returned "
"unexpected result with status and/or query id missing. "
"Rerun the command and reach out to Firebolt support if "
"the issue persists."
)

# Only pass the expected keys to AsyncQueryInfo
filtered_result_dict = {
k: v for k, v in result_dict.items() if k in AsyncQueryInfo._fields
}

return AsyncQueryInfo(**filtered_result_dict)

async def is_async_query_running(self, token: str) -> bool:
"""
Expand All @@ -114,8 +130,8 @@ async def is_async_query_running(self, token: str) -> bool:
Returns:
bool: True if async query is still running, False otherwise
"""
status = await self._get_async_query_status(token)
return status == ASYNC_QUERY_STATUS_RUNNING
async_query_details = await self._get_async_query_info(token)
return async_query_details.status == ASYNC_QUERY_STATUS_RUNNING

async def is_async_query_successful(self, token: str) -> Optional[bool]:
"""
Expand All @@ -128,10 +144,22 @@ async def is_async_query_successful(self, token: str) -> Optional[bool]:
bool: None if the query is still running, True if successful,
False otherwise
"""
status = await self._get_async_query_status(token)
if status == ASYNC_QUERY_STATUS_RUNNING:
async_query_details = await self._get_async_query_info(token)
if async_query_details.status == ASYNC_QUERY_STATUS_RUNNING:
return None
return status == ASYNC_QUERY_STATUS_SUCCESSFUL
return async_query_details.status == ASYNC_QUERY_STATUS_SUCCESSFUL

async def cancel_async_query(self, token: str) -> None:
"""
Cancel an async query.
Args:
token: Async query token. Can be obtained from Cursor.async_query_token.
"""
async_query_details = await self._get_async_query_info(token)
async_query_id = async_query_details.query_id
cursor = self.cursor()
await cursor.execute(ASYNC_QUERY_CANCEL, [async_query_id])

# Context manager support
async def __aenter__(self) -> Connection:
Expand Down
22 changes: 21 additions & 1 deletion src/firebolt/common/base_connection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,30 @@
from collections import namedtuple
from typing import Any, List, Type

from firebolt.utils.exception import ConnectionClosedError

ASYNC_QUERY_STATUS_RUNNING = "RUNNING"
ASYNC_QUERY_STATUS_SUCCESSFUL = "ENDED_SUCCESSFULLY"
ASYNC_QUERY_STATUS_REQUEST = "CALL fb_GetAsyncStatus('{token}')"
ASYNC_QUERY_STATUS_REQUEST = "CALL fb_GetAsyncStatus(?)"
ASYNC_QUERY_CANCEL = "CANCEL QUERY WHERE query_id=?"

AsyncQueryInfo = namedtuple(
"AsyncQueryInfo",
[
"account_name",
"user_name",
"submitted_time",
"start_time",
"end_time",
"status",
"request_id",
"query_id",
"error_message",
"scanned_bytes",
"scanned_rows",
"retries",
],
)


class BaseConnection:
Expand Down
42 changes: 35 additions & 7 deletions src/firebolt/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from firebolt.client import DEFAULT_API_URL, Client, ClientV1, ClientV2
from firebolt.client.auth import Auth
from firebolt.common.base_connection import (
ASYNC_QUERY_CANCEL,
ASYNC_QUERY_STATUS_REQUEST,
ASYNC_QUERY_STATUS_RUNNING,
ASYNC_QUERY_STATUS_SUCCESSFUL,
AsyncQueryInfo,
BaseConnection,
)
from firebolt.common.cache import _firebolt_system_engine_cache
Expand Down Expand Up @@ -227,19 +229,34 @@ def close(self) -> None:
self._is_closed = True

# Server-side async methods
def _get_async_query_status(self, token: str) -> str:

def _get_async_query_info(self, token: str) -> AsyncQueryInfo:
if self.cursor_type != CursorV2:
raise FireboltError(
"This method is only supported for connection with service account."
)
cursor = self.cursor()
cursor.execute(ASYNC_QUERY_STATUS_REQUEST.format(token=token))
cursor.execute(ASYNC_QUERY_STATUS_REQUEST, [token])
result = cursor.fetchone()
if cursor.rowcount != 1 or not result:
raise FireboltError("Unexpected result from async query status request.")
columns = cursor.description
result_dict = dict(zip([column.name for column in columns], result))
return result_dict["status"]

if not result_dict.get("status") or not result_dict.get("query_id"):
raise FireboltError(
"Something went wrong - async query status request returned "
"unexpected result with status and/or query id missing. "
"Rerun the command and reach out to Firebolt support if "
"the issue persists."
)

# Only pass the expected keys to AsyncQueryInfo
filtered_result_dict = {
k: v for k, v in result_dict.items() if k in AsyncQueryInfo._fields
}

return AsyncQueryInfo(**filtered_result_dict)

def is_async_query_running(self, token: str) -> bool:
"""
Expand All @@ -251,7 +268,7 @@ def is_async_query_running(self, token: str) -> bool:
Returns:
bool: True if async query is still running, False otherwise
"""
return self._get_async_query_status(token) == ASYNC_QUERY_STATUS_RUNNING
return self._get_async_query_info(token).status == ASYNC_QUERY_STATUS_RUNNING

def is_async_query_successful(self, token: str) -> Optional[bool]:
"""
Expand All @@ -264,10 +281,21 @@ def is_async_query_successful(self, token: str) -> Optional[bool]:
bool: None if the query is still running, True if successful,
False otherwise
"""
status = self._get_async_query_status(token)
if status == ASYNC_QUERY_STATUS_RUNNING:
async_query_info = self._get_async_query_info(token)
if async_query_info.status == ASYNC_QUERY_STATUS_RUNNING:
return None
return status == ASYNC_QUERY_STATUS_SUCCESSFUL
return async_query_info.status == ASYNC_QUERY_STATUS_SUCCESSFUL

def cancel_async_query(self, token: str) -> None:
"""
Cancel an async query.
Args:
token: Async query token. Can be obtained from Cursor.async_query_token.
"""
async_query_id = self._get_async_query_info(token).query_id
cursor = self.cursor()
cursor.execute(ASYNC_QUERY_CANCEL, [async_query_id])

# Context manager support
def __enter__(self) -> Connection:
Expand Down
20 changes: 20 additions & 0 deletions tests/integration/dbapi/async/V2/test_server_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,23 @@ async def test_check_async_execution_fails(connection: Connection) -> None:
await cursor.execute_async(f"MALFORMED QUERY")
with raises(FireboltError):
cursor.async_query_token


async def test_cancel_async_query(connection: Connection) -> None:
cursor = connection.cursor()
rnd_suffix = str(randint(0, 1000))
table_name = f"test_insert_async_{rnd_suffix}"
try:
await cursor.execute(f"CREATE TABLE {table_name} (id LONG)")
await cursor.execute_async(f"INSERT INTO {table_name} {LONG_SELECT}")
token = cursor.async_query_token
assert token is not None, "Async token was not returned"
assert await connection.is_async_query_running(token) == True
await connection.cancel_async_query(token)
assert await connection.is_async_query_running(token) == False
assert await connection.is_async_query_successful(token) == False
await cursor.execute(f"SELECT * FROM {table_name}")
result = await cursor.fetchall()
assert result == []
finally:
await cursor.execute(f"DROP TABLE {table_name}")
20 changes: 20 additions & 0 deletions tests/integration/dbapi/sync/V2/test_server_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,23 @@ def test_check_async_execution_fails(connection: Connection) -> None:
cursor.execute_async(f"MALFORMED QUERY")
with raises(FireboltError):
cursor.async_query_token


def test_cancel_async_query(connection: Connection) -> None:
cursor = connection.cursor()
rnd_suffix = str(randint(0, 1000))
table_name = f"test_insert_async_{rnd_suffix}"
try:
cursor.execute(f"CREATE TABLE {table_name} (id LONG)")
cursor.execute_async(f"INSERT INTO {table_name} {LONG_SELECT}")
token = cursor.async_query_token
assert token is not None, "Async token was not returned"
assert connection.is_async_query_running(token) == True
connection.cancel_async_query(token)
assert connection.is_async_query_running(token) == False
assert connection.is_async_query_successful(token) == False
cursor.execute(f"SELECT * FROM {table_name}")
result = cursor.fetchall()
assert result == []
finally:
cursor.execute(f"DROP TABLE {table_name}")
86 changes: 86 additions & 0 deletions tests/unit/async_db/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,3 +492,89 @@ async def test_async_query_status_unexpected_result(
await connection.is_async_query_running("token")
with raises(FireboltError):
await connection.is_async_query_successful("token")


async def test_async_query_status_no_id_or_status(
db_name: str,
account_name: str,
engine_name: str,
auth: Auth,
api_endpoint: str,
httpx_mock: HTTPXMock,
query_url: str,
async_query_callback_factory: Callable,
async_query_meta: List[Tuple[str, str]],
async_query_data: List[List[ColType]],
mock_connection_flow: Callable,
):
mock_connection_flow()
data_no_query_id = async_query_data[0].copy()
data_no_query_id[7] = ""
data_no_query_status = async_query_data[0].copy()
data_no_query_status[5] = ""
for data_case in [data_no_query_id, data_no_query_status]:
async_query_status_running_callback = async_query_callback_factory(
[data_case], async_query_meta
)
httpx_mock.add_callback(
async_query_status_running_callback,
url=query_url,
match_content="CALL fb_GetAsyncStatus('token')".encode("utf-8"),
)
async with await connect(
database=db_name,
auth=auth,
engine_name=engine_name,
account_name=account_name,
api_endpoint=api_endpoint,
) as connection:
with raises(FireboltError):
await connection.is_async_query_running("token")
with raises(FireboltError):
await connection.is_async_query_successful("token")


async def test_async_query_cancellation(
db_name: str,
account_name: str,
engine_name: str,
auth: Auth,
api_endpoint: str,
httpx_mock: HTTPXMock,
query_url: str,
query_callback: Callable,
async_query_callback_factory: Callable,
async_query_data: List[List[ColType]],
async_query_meta: List[Tuple[str, str]],
mock_connection_flow: Callable,
):
"""Test async query cancellation"""
mock_connection_flow()
async_query_data[0][5] = "RUNNING"
async_query_status_running_callback = async_query_callback_factory(
async_query_data, async_query_meta
)

query_dict = dict(zip([m[0] for m in async_query_meta], async_query_data[0]))
query_id = query_dict["query_id"]

httpx_mock.add_callback(
async_query_status_running_callback,
url=query_url,
match_content="CALL fb_GetAsyncStatus('token')".encode("utf-8"),
)

httpx_mock.add_callback(
query_callback,
url=query_url,
match_content=f"CANCEL QUERY WHERE query_id='{query_id}'".encode("utf-8"),
)

async with await connect(
database=db_name,
auth=auth,
engine_name=engine_name,
account_name=account_name,
api_endpoint=api_endpoint,
) as connection:
await connection.cancel_async_query("token")
Loading

0 comments on commit 3d92df9

Please sign in to comment.