diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py
index e2e0e82f6b0eb..fc19db988126f 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -17,13 +17,20 @@
 from __future__ import annotations
 
 import asyncio
-from typing import Any, AsyncIterator, Sequence, SupportsAbs
+from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence, SupportsAbs
 
 from aiohttp import ClientSession
 from aiohttp.client_exceptions import ClientResponseError
 
+from airflow.exceptions import AirflowException
+from airflow.models.taskinstance import TaskInstance
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.session import provide_session
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session
 
 
 class BigQueryInsertJobTrigger(BaseTrigger):
@@ -89,6 +96,43 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )
 
+    @provide_session
+    def get_task_instance(self, session: Session) -> TaskInstance:
+        """
+        Fetch the latest database row of the task instance this trigger is attached to.
+
+        :param session: SQLAlchemy session, injected by ``provide_session`` when not passed.
+        :return: the matching ``TaskInstance``.
+        :raises AirflowException: if no matching task instance exists.
+        """
+        query = session.query(TaskInstance).filter(
+            TaskInstance.dag_id == self.task_instance.dag_id,
+            TaskInstance.task_id == self.task_instance.task_id,
+            TaskInstance.run_id == self.task_instance.run_id,
+            TaskInstance.map_index == self.task_instance.map_index,
+        )
+        task_instance = query.one_or_none()
+        if task_instance is None:
+            # Exceptions do not interpolate logging-style ``%s`` arguments, so build the message eagerly.
+            raise AirflowException(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} and "
+                f"map_index: {self.task_instance.map_index} is not found"
+            )
+        return task_instance
+
+    def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the external job which is being executed by this trigger.
+
+        This is to avoid the case that `asyncio.CancelledError` is raised because the trigger itself is stopped.
+        Because in those cases, we should NOT cancel the external job.
+        """
+        # Database query is needed to get the latest state of the task instance.
+        task_instance = self.get_task_instance()  # type: ignore[call-arg]
+        return task_instance.state != TaskInstanceState.DEFERRED
+
     async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
         """Get current job execution status and yields a TriggerEvent."""
         hook = self._get_async_hook()
@@ -117,13 +161,27 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
                     )
                     await asyncio.sleep(self.poll_interval)
         except asyncio.CancelledError:
-            self.log.info("Task was killed.")
-            if self.job_id and self.cancel_on_kill:
+            if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+                self.log.info(
+                    "The job is safe to cancel as the airflow TaskInstance is not in deferred state."
+                )
+                self.log.info(
+                    "Cancelling job. Project ID: %s, Location: %s, Job ID: %s",
+                    self.project_id,
+                    self.location,
+                    self.job_id,
+                )
                 await hook.cancel_job(  # type: ignore[union-attr]
                     job_id=self.job_id, project_id=self.project_id, location=self.location
                 )
             else:
-                self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.location, self.job_id)
+                self.log.info(
+                    "Trigger may have shutdown. Skipping to cancel job because the airflow "
+                    "task is not cancelled yet: Project ID: %s, Location:%s, Job ID:%s",
+                    self.project_id,
+                    self.location,
+                    self.job_id,
+                )
         except Exception as e:
             self.log.exception("Exception occurred while checking for query completion")
             yield TriggerEvent({"status": "error", "message": str(e)})
@@ -148,6 +206,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "table_id": self.table_id,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
+                "cancel_on_kill": self.cancel_on_kill,
             },
         )
 
diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py
index 436872903eb5b..bbb1a50356882 100644
--- a/tests/providers/google/cloud/triggers/test_bigquery.py
+++ b/tests/providers/google/cloud/triggers/test_bigquery.py
@@ -239,13 +239,15 @@ async def test_bigquery_op_trigger_exception(self, mock_job_status, caplog, inse
     @pytest.mark.asyncio
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+    @mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger.safe_to_cancel")
     async def test_bigquery_insert_job_trigger_cancellation(
-        self, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
+        self, mock_safe_to_cancel, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
     ):
         """
         Test that BigQueryInsertJobTrigger handles cancellation correctly, logs the appropriate message,
         and conditionally cancels the job based on the `cancel_on_kill` attribute.
         """
+        mock_safe_to_cancel.return_value = True
         insert_job_trigger.cancel_on_kill = True
         insert_job_trigger.job_id = "1234"
 
@@ -271,6 +273,42 @@ async def test_bigquery_insert_job_trigger_cancellation(
         ), "Expected messages about task status or cancellation not found in log."
         mock_cancel_job.assert_awaited_once()
 
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+    @mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger.safe_to_cancel")
+    async def test_bigquery_insert_job_trigger_cancellation_unsafe_cancellation(
+        self, mock_safe_to_cancel, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
+    ):
+        """
+        Test that BigQueryInsertJobTrigger logs the appropriate message and does not cancel the job
+        if safe_to_cancel returns False even when the task is cancelled.
+        """
+        mock_safe_to_cancel.return_value = False
+        insert_job_trigger.cancel_on_kill = True
+        insert_job_trigger.job_id = "1234"
+
+        # Simulate the initial job status as running
+        mock_get_job_status.side_effect = [
+            {"status": "running", "message": "Job is still running"},
+            asyncio.CancelledError(),
+            {"status": "running", "message": "Job is still running after cancellation"},
+        ]
+
+        caplog.set_level(logging.INFO)
+
+        try:
+            async for _ in insert_job_trigger.run():
+                pass
+        except asyncio.CancelledError:
+            pass
+
+        assert (
+            "Skipping to cancel job" in caplog.text
+        ), "Expected message about skipping cancellation not found in log."
+        assert mock_get_job_status.call_count == 2, "Job status should be checked multiple times"
+        mock_cancel_job.assert_not_awaited()
+
 
 class TestBigQueryGetDataTrigger:
     def test_bigquery_get_data_trigger_serialization(self, get_data_trigger):
@@ -447,6 +485,7 @@ def test_check_trigger_serialization(self, check_trigger):
             "table_id": TEST_TABLE_ID,
             "location": None,
             "poll_interval": POLLING_PERIOD_SECONDS,
+            "cancel_on_kill": True,
         }
 
     @pytest.mark.asyncio