Skip to content

Commit

Permalink
Fix logic to cancel the external job if the TaskInstance is not in a …
Browse files Browse the repository at this point in the history
…running or deferred state for BigQueryInsertJobOperator (#39442)
  • Loading branch information
sunank200 authored May 8, 2024
1 parent 73587ba commit e7aa4d2
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 5 deletions.
60 changes: 56 additions & 4 deletions airflow/providers/google/cloud/triggers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,20 @@
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator, Sequence, SupportsAbs
from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence, SupportsAbs

from aiohttp import ClientSession
from aiohttp.client_exceptions import ClientResponseError

from airflow.exceptions import AirflowException
from airflow.models.taskinstance import TaskInstance
from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import provide_session
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session


class BigQueryInsertJobTrigger(BaseTrigger):
Expand Down Expand Up @@ -89,6 +96,36 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
},
)

@provide_session
def get_task_instance(self, session: Session) -> TaskInstance:
    """
    Load the current database record of the task instance attached to this trigger.

    The lookup matches on ``dag_id``, ``task_id``, ``run_id`` and ``map_index`` of
    ``self.task_instance``.

    :param session: SQLAlchemy session used for the query; callers in this class omit it and
        rely on the ``provide_session`` decorator.
    :return: The matching ``TaskInstance`` row.
    :raises AirflowException: If no matching task instance is found.
    """
    ti = self.task_instance
    task_instance = (
        session.query(TaskInstance)
        .filter(
            TaskInstance.dag_id == ti.dag_id,
            TaskInstance.task_id == ti.task_id,
            TaskInstance.run_id == ti.run_id,
            TaskInstance.map_index == ti.map_index,
        )
        .one_or_none()
    )
    if task_instance is None:
        # Exceptions do not interpolate logging-style "%s" arguments: passing them positionally
        # would leave the placeholders unfilled and render the message as a tuple. Format the
        # message up front instead.
        raise AirflowException(
            f"TaskInstance with dag_id: {ti.dag_id}, task_id: {ti.task_id}, "
            f"run_id: {ti.run_id} and map_index: {ti.map_index} is not found"
        )
    return task_instance

def safe_to_cancel(self) -> bool:
    """
    Tell whether the external job executed by this trigger may be cancelled.

    ``asyncio.CancelledError`` can also be raised because the trigger itself is being stopped,
    and in that situation the external job must NOT be cancelled. Cancellation is therefore
    reported as safe only when the task instance is no longer in the deferred state.

    :return: ``True`` if the task instance state is anything other than ``DEFERRED``.
    """
    # Read the state through a fresh database query so that the latest value is used.
    return self.get_task_instance().state != TaskInstanceState.DEFERRED  # type: ignore[call-arg]

async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current job execution status and yields a TriggerEvent."""
hook = self._get_async_hook()
Expand Down Expand Up @@ -117,13 +154,27 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
)
await asyncio.sleep(self.poll_interval)
except asyncio.CancelledError:
self.log.info("Task was killed.")
if self.job_id and self.cancel_on_kill:
if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
self.log.info(
"The job is safe to cancel the as airflow TaskInstance is not in deferred state."
)
self.log.info(
"Cancelling job. Project ID: %s, Location: %s, Job ID: %s",
self.project_id,
self.location,
self.job_id,
)
await hook.cancel_job( # type: ignore[union-attr]
job_id=self.job_id, project_id=self.project_id, location=self.location
)
else:
self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.location, self.job_id)
self.log.info(
"Trigger may have shutdown. Skipping to cancel job because the airflow "
"task is not cancelled yet: Project ID: %s, Location:%s, Job ID:%s",
self.project_id,
self.location,
self.job_id,
)
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})
Expand All @@ -148,6 +199,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
"cancel_on_kill": self.cancel_on_kill,
},
)

Expand Down
40 changes: 39 additions & 1 deletion tests/providers/google/cloud/triggers/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,15 @@ async def test_bigquery_op_trigger_exception(self, mock_job_status, caplog, inse
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger.safe_to_cancel")
async def test_bigquery_insert_job_trigger_cancellation(
self, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
self, mock_get_task_instance, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
):
"""
Test that BigQueryInsertJobTrigger handles cancellation correctly, logs the appropriate message,
and conditionally cancels the job based on the `cancel_on_kill` attribute.
"""
mock_get_task_instance.return_value = True
insert_job_trigger.cancel_on_kill = True
insert_job_trigger.job_id = "1234"

Expand All @@ -271,6 +273,41 @@ async def test_bigquery_insert_job_trigger_cancellation(
), "Expected messages about task status or cancellation not found in log."
mock_cancel_job.assert_awaited_once()

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger.safe_to_cancel")
async def test_bigquery_insert_job_trigger_cancellation_unsafe_cancellation(
    self, mock_safe_to_cancel, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
):
    """
    Test that BigQueryInsertJobTrigger logs the appropriate message and does not cancel the job
    if safe_to_cancel returns False even when the task is cancelled.
    """
    mock_safe_to_cancel.return_value = False
    insert_job_trigger.cancel_on_kill = True
    insert_job_trigger.job_id = "1234"

    # First poll reports a running job, the second poll raises the cancellation.
    # The third entry is a guard value: the call-count assertion below expects it to stay unused.
    mock_get_job_status.side_effect = [
        {"status": "running", "message": "Job is still running"},
        asyncio.CancelledError(),
        {"status": "running", "message": "Job is still running after cancellation"},
    ]

    caplog.set_level(logging.INFO)

    try:
        async for _ in insert_job_trigger.run():
            pass
    except asyncio.CancelledError:
        # Tolerate the cancellation propagating out of the trigger; only the side effects matter.
        pass

    assert (
        "Skipping to cancel job" in caplog.text
    ), "Expected message about skipping cancellation not found in log."
    assert mock_get_job_status.call_count == 2, "Job status should be checked multiple times"
    # The core contract of this test: an unsafe cancellation must never reach the BigQuery API.
    mock_cancel_job.assert_not_awaited()


class TestBigQueryGetDataTrigger:
def test_bigquery_get_data_trigger_serialization(self, get_data_trigger):
Expand Down Expand Up @@ -447,6 +484,7 @@ def test_check_trigger_serialization(self, check_trigger):
"table_id": TEST_TABLE_ID,
"location": None,
"poll_interval": POLLING_PERIOD_SECONDS,
"cancel_on_kill": True,
}

@pytest.mark.asyncio
Expand Down

0 comments on commit e7aa4d2

Please sign in to comment.