From 8b38aa4f9c2bb3db384d1daf04171f465b4c62e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Fri, 23 Jun 2023 11:25:07 -0700 Subject: [PATCH] bugfix: break down run+wait method in ECS operator This method is just causing trouble by handling several things, it's hiding the logic. A bug fixed in #31838 was reintroduced in #31881 because the check that was skipped on `wait_for_completion` was not skipped anymore. The bug is that checking the status will always fail if not waiting for completion, because obviously the task is not ready just after creation. --- airflow/providers/amazon/aws/operators/ecs.py | 77 +++++++++---------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 2c2e93af3582a..91533cfa62112 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -539,46 +539,8 @@ def execute(self, context, session=None): if self.reattach: self._try_reattach_task(context) - self._start_wait_task(context) - - self._after_execution(session) - - if self.do_xcom_push and self.task_log_fetcher: - return self.task_log_fetcher.get_last_log_message() - else: - return None - - def execute_complete(self, context, event=None): - if event["status"] != "success": - raise AirflowException(f"Error in task execution: {event}") - self.arn = event["task_arn"] # restore arn to its updated value, needed for next steps - self._after_execution() - if self._aws_logs_enabled(): - # same behavior as non-deferrable mode, return last line of logs of the task. - logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).conn - one_log = logs_client.get_log_events( - logGroupName=self.awslogs_group, - logStreamName=self._get_logs_stream_name(), - startFromHead=False, - limit=1, - ) - if len(one_log["events"]) > 0: - return one_log["events"][0]["message"] - - @provide_session - def _after_execution(self, session=None): - self._check_success_task() - - self.log.info("ECS Task has been successfully executed") - - if self.reattach: - # Clear the XCom value storing the ECS task ARN if the task has completed - # as we can't reattach it anymore - self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id)) - - @AwsBaseHook.retry(should_retry_eni) - def _start_wait_task(self, context): if not self.arn: + # start the task except if we reattached to an existing one just before. self._start_task(context) if self.deferrable: @@ -598,6 +560,7 @@ def _start_wait_task(self, context): # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent) timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60), ) + # self.defer raises a special exception, so execution stops here in this case. if not self.wait_for_completion: return @@ -615,9 +578,45 @@ def _start_wait_task(self, context): else: self._wait_for_task_ended() + self._after_execution(session) + + if self.do_xcom_push and self.task_log_fetcher: + return self.task_log_fetcher.get_last_log_message() + else: + return None + + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error in task execution: {event}") + self.arn = event["task_arn"] # restore arn to its updated value, needed for next steps + self._after_execution() + if self._aws_logs_enabled(): + # same behavior as non-deferrable mode, return last line of logs of the task. + logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).conn + one_log = logs_client.get_log_events( + logGroupName=self.awslogs_group, + logStreamName=self._get_logs_stream_name(), + startFromHead=False, + limit=1, + ) + if len(one_log["events"]) > 0: + return one_log["events"][0]["message"] + + @provide_session + def _after_execution(self, session=None): + self._check_success_task() + + self.log.info("ECS Task has been successfully executed") + + if self.reattach: + # Clear the XCom value storing the ECS task ARN if the task has completed + # as we can't reattach it anymore + self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id)) + def _xcom_del(self, session, task_id): session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == task_id).delete() + @AwsBaseHook.retry(should_retry_eni) def _start_task(self, context): run_opts = { "cluster": self.cluster,