diff --git a/airflow/providers/microsoft/azure/hooks/batch.py b/airflow/providers/microsoft/azure/hooks/batch.py index 7b681b3a9d7f3..2d26cc9f03f71 100644 --- a/airflow/providers/microsoft/azure/hooks/batch.py +++ b/airflow/providers/microsoft/azure/hooks/batch.py @@ -378,7 +378,7 @@ def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batc """ timeout_time = timezone.utcnow() + timedelta(minutes=timeout) while timezone.utcnow() < timeout_time: - tasks = self.connection.task.list(job_id) + tasks = list(self.connection.task.list(job_id)) incomplete_tasks = [task for task in tasks if task.state != batch_models.TaskState.completed] if not incomplete_tasks: @@ -386,7 +386,7 @@ def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batc fail_tasks = [ task for task in tasks - if task.executionInfo.result == batch_models.TaskExecutionResult.failure + if task.execution_info.result == batch_models.TaskExecutionResult.failure ] return fail_tasks for task in incomplete_tasks: diff --git a/tests/providers/microsoft/azure/hooks/test_batch.py b/tests/providers/microsoft/azure/hooks/test_batch.py index d9daa2f13d435..dd696911b2122 100644 --- a/tests/providers/microsoft/azure/hooks/test_batch.py +++ b/tests/providers/microsoft/azure/hooks/test_batch.py @@ -190,9 +190,61 @@ def test_add_single_task_to_job(self, mock_batch): mock_instance.assert_called_once_with(job_id="myjob", task=task) @mock.patch(f"{MODULE}.BatchServiceClient") - def test_wait_for_all_task_to_complete(self, mock_batch): - # TODO: Add test - pass + def test_wait_for_all_task_to_complete_timeout(self, mock_batch): + hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) + with pytest.raises(TimeoutError): + hook.wait_for_job_tasks_to_complete("myjob", -1) + + @mock.patch(f"{MODULE}.BatchServiceClient") + def test_wait_for_all_task_to_complete_all_success(self, mock_batch): + hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) + hook.connection.task.list.return_value = iter( + [ + batch_models.CloudTask( + id="mytask_1", + execution_info=batch_models.TaskExecutionInformation( + retry_count=0, requeue_count=0, result=batch_models.TaskExecutionResult.success + ), + state=batch_models.TaskState.completed, + ), + batch_models.CloudTask( + id="mytask_2", + execution_info=batch_models.TaskExecutionInformation( + retry_count=0, requeue_count=0, result=batch_models.TaskExecutionResult.success + ), + state=batch_models.TaskState.completed, + ), + ] + ) + + results = hook.wait_for_job_tasks_to_complete("myjob", 60) + assert results == [] + hook.connection.task.list.assert_called_once_with("myjob") + + @mock.patch(f"{MODULE}.BatchServiceClient") + def test_wait_for_all_task_to_complete_failures(self, mock_batch): + hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) + tasks = [ + batch_models.CloudTask( + id="mytask_1", + execution_info=batch_models.TaskExecutionInformation( + retry_count=0, requeue_count=0, result=batch_models.TaskExecutionResult.success + ), + state=batch_models.TaskState.completed, + ), + batch_models.CloudTask( + id="mytask_2", + execution_info=batch_models.TaskExecutionInformation( + retry_count=0, requeue_count=0, result=batch_models.TaskExecutionResult.failure + ), + state=batch_models.TaskState.completed, + ), + ] + hook.connection.task.list.return_value = iter(tasks) + + results = hook.wait_for_job_tasks_to_complete("myjob", 60) + assert results == [tasks[1]] + hook.connection.task.list.assert_called_once_with("myjob") @mock.patch(f"{MODULE}.BatchServiceClient") def test_connection_success(self, mock_batch):