Skip to content

Commit

Permalink
Fix failed tasks are not detected in AzureBatchHook (#36785)
Browse files Browse the repository at this point in the history
* fix wait_for_job_tasks_to_complete in microsoft-azure provider

* add unit tests for wait_for_all_tasks_to_complete

* fix azure batch hook unit tests

* fix azure batch hook tests
  • Loading branch information
ArvidMartensRenson authored Jan 16, 2024
1 parent 81be6ac commit 57c9211
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 5 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/microsoft/azure/hooks/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,15 +378,15 @@ def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batc
"""
timeout_time = timezone.utcnow() + timedelta(minutes=timeout)
while timezone.utcnow() < timeout_time:
tasks = self.connection.task.list(job_id)
tasks = list(self.connection.task.list(job_id))

incomplete_tasks = [task for task in tasks if task.state != batch_models.TaskState.completed]
if not incomplete_tasks:
# detect if any task in job has failed
fail_tasks = [
task
for task in tasks
if task.executionInfo.result == batch_models.TaskExecutionResult.failure
if task.execution_info.result == batch_models.TaskExecutionResult.failure
]
return fail_tasks
for task in incomplete_tasks:
Expand Down
58 changes: 55 additions & 3 deletions tests/providers/microsoft/azure/hooks/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,61 @@ def test_add_single_task_to_job(self, mock_batch):
mock_instance.assert_called_once_with(job_id="myjob", task=task)

@mock.patch(f"{MODULE}.BatchServiceClient")
def test_wait_for_all_task_to_complete(self, mock_batch):
# TODO: Add test
pass
def test_wait_for_all_task_to_complete_timeout(self, mock_batch):
hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
with pytest.raises(TimeoutError):
hook.wait_for_job_tasks_to_complete("myjob", -1)

@mock.patch(f"{MODULE}.BatchServiceClient")
def test_wait_for_all_task_to_complete_all_success(self, mock_batch):
hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
hook.connection.task.list.return_value = iter(
[
batch_models.CloudTask(
id="mytask_1",
execution_info=batch_models.TaskExecutionInformation(
retry_count=0, requeue_count=0, result=batch_models.TaskExecutionResult.success
),
state=batch_models.TaskState.completed,
),
batch_models.CloudTask(
id="mytask_2",
execution_info=batch_models.TaskExecutionInformation(
retry_count=0, requeue_count=0, result=batch_models.TaskExecutionResult.success
),
state=batch_models.TaskState.completed,
),
]
)

results = hook.wait_for_job_tasks_to_complete("myjob", 60)
assert results == []
hook.connection.task.list.assert_called_once_with("myjob")

@mock.patch(f"{MODULE}.BatchServiceClient")
def test_wait_for_all_task_to_complete_failures(self, mock_batch):
hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
tasks = [
batch_models.CloudTask(
id="mytask_1",
execution_info=batch_models.TaskExecutionInformation(
retry_count=0, requeue_count=0, result=batch_models.TaskExecutionResult.success
),
state=batch_models.TaskState.completed,
),
batch_models.CloudTask(
id="mytask_2",
execution_info=batch_models.TaskExecutionInformation(
retry_count=0, requeue_count=0, result=batch_models.TaskExecutionResult.failure
),
state=batch_models.TaskState.completed,
),
]
hook.connection.task.list.return_value = iter(tasks)

results = hook.wait_for_job_tasks_to_complete("myjob", 60)
assert results == [tasks[1]]
hook.connection.task.list.assert_called_once_with("myjob")

@mock.patch(f"{MODULE}.BatchServiceClient")
def test_connection_success(self, mock_batch):
Expand Down

0 comments on commit 57c9211

Please sign in to comment.