Skip to content

Commit

Permalink
Fix logic to cancel the external job if the TaskInstance is not in a …
Browse files Browse the repository at this point in the history
…running or deferred state for DataprocSubmitJobOperator (#39447)
  • Loading branch information
sunank200 authored May 8, 2024
1 parent e7aa4d2 commit 387acd0
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
41 changes: 40 additions & 1 deletion airflow/providers/google/cloud/triggers/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,41 @@ def serialize(self):
},
)

@provide_session
def get_task_instance(self, session: Session) -> TaskInstance:
"""
Get the task instance for the current task.
:param session: Sqlalchemy session
"""
query = session.query(TaskInstance).filter(
TaskInstance.dag_id == self.task_instance.dag_id,
TaskInstance.task_id == self.task_instance.task_id,
TaskInstance.run_id == self.task_instance.run_id,
TaskInstance.map_index == self.task_instance.map_index,
)
task_instance = query.one_or_none()
if task_instance is None:
raise AirflowException(
"TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
self.task_instance.dag_id,
self.task_instance.task_id,
self.task_instance.run_id,
self.task_instance.map_index,
)
return task_instance

def safe_to_cancel(self) -> bool:
"""
Whether it is safe to cancel the external job which is being executed by this trigger.
This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
Because in those cases, we should NOT cancel the external job.
"""
# Database query is needed to get the latest state of the task instance.
task_instance = self.get_task_instance() # type: ignore[call-arg]
return task_instance.state != TaskInstanceState.DEFERRED

async def run(self):
try:
while True:
Expand All @@ -131,7 +166,11 @@ async def run(self):
except asyncio.CancelledError:
self.log.info("Task got cancelled.")
try:
if self.job_id and self.cancel_on_kill:
if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
self.log.info(
"Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
" in deferred state."
)
self.log.info("Cancelling the job: %s", self.job_id)
# The synchronous hook is utilized to delete the cluster when a task is cancelled. This
# is because the asynchronous hook deletion is not awaited when the trigger task is
Expand Down
8 changes: 6 additions & 2 deletions tests/providers/google/cloud/triggers/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def submit_trigger():
region=TEST_REGION,
gcp_conn_id=TEST_GCP_CONN_ID,
polling_interval_seconds=TEST_POLL_INTERVAL,
cancel_on_kill=True,
)


Expand Down Expand Up @@ -569,12 +570,15 @@ async def test_submit_trigger_run_error(self, mock_get_async_hook, submit_trigge
assert event.payload == expected_event.payload

@pytest.mark.asyncio
@pytest.mark.parametrize("is_safe_to_cancel", [True, False])
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_sync_hook")
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.safe_to_cancel")
async def test_submit_trigger_run_cancelled(
self, mock_get_sync_hook, mock_get_async_hook, submit_trigger
self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, submit_trigger, is_safe_to_cancel
):
"""Test the trigger correctly handles an asyncio.CancelledError."""
mock_safe_to_cancel.return_value = is_safe_to_cancel
mock_async_hook = mock_get_async_hook.return_value
mock_async_hook.get_job.side_effect = asyncio.CancelledError

Expand All @@ -598,7 +602,7 @@ async def test_submit_trigger_run_cancelled(
pytest.fail(f"Unexpected exception raised: {e}")

# Check if cancel_job was correctly called
if submit_trigger.cancel_on_kill:
if submit_trigger.cancel_on_kill and is_safe_to_cancel:
mock_sync_hook.cancel_job.assert_called_once_with(
job_id=submit_trigger.job_id,
project_id=submit_trigger.project_id,
Expand Down

0 comments on commit 387acd0

Please sign in to comment.