bugfix: break down run+wait method in ECS operator (apache#32104)
This method was just causing trouble: by handling several things at once, it hid the underlying logic.
A bug fixed in apache#31838 was reintroduced in apache#31881, because the check that used to be skipped when `wait_for_completion` is False was no longer being skipped.

The bug is that checking the task's status always fails when we are not waiting for completion, because the task obviously cannot be done just after it was created.
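
For illustration only, here is a minimal sketch of the control flow this change restores. The helper names are hypothetical stand-ins, not the operator's real methods: start the task, return immediately when `wait_for_completion` is False, and only check the task's status after actually waiting for it to stop.

def start_task() -> str:
    # hypothetical stand-in for starting an ECS task (e.g. a boto3 run_task call)
    return "arn:aws:ecs:us-east-1:123456789012:task/demo"


def wait_for_task_ended(task_arn: str) -> None:
    # hypothetical stand-in for polling until the task reaches a terminal state
    pass


def check_success_task(task_arn: str) -> None:
    # hypothetical stand-in for the success check; run against a task that was only
    # just started, this check would always fail, which is the bug described above
    pass


def execute(wait_for_completion: bool) -> str:
    task_arn = start_task()

    if not wait_for_completion:
        # return before any status check: the task cannot be in a terminal state yet
        return task_arn

    wait_for_task_ended(task_arn)
    check_success_task(task_arn)  # safe now: the task has stopped
    return task_arn


if __name__ == "__main__":
    execute(wait_for_completion=False)  # no premature status check
    execute(wait_for_completion=True)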
vandonr-amz authored and ferruzzi committed Jun 27, 2023
1 parent 77f3e6d commit 4ac51e9
Showing 1 changed file with 38 additions and 39 deletions.
77 changes: 38 additions & 39 deletions airflow/providers/amazon/aws/operators/ecs.py
@@ -539,46 +539,8 @@ def execute(self, context, session=None):
if self.reattach:
self._try_reattach_task(context)

self._start_wait_task(context)

self._after_execution(session)

if self.do_xcom_push and self.task_log_fetcher:
return self.task_log_fetcher.get_last_log_message()
else:
return None

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error in task execution: {event}")
self.arn = event["task_arn"] # restore arn to its updated value, needed for next steps
self._after_execution()
if self._aws_logs_enabled():
# same behavior as non-deferrable mode, return last line of logs of the task.
logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).conn
one_log = logs_client.get_log_events(
logGroupName=self.awslogs_group,
logStreamName=self._get_logs_stream_name(),
startFromHead=False,
limit=1,
)
if len(one_log["events"]) > 0:
return one_log["events"][0]["message"]

@provide_session
def _after_execution(self, session=None):
self._check_success_task()

self.log.info("ECS Task has been successfully executed")

if self.reattach:
# Clear the XCom value storing the ECS task ARN if the task has completed
# as we can't reattach it anymore
self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))

@AwsBaseHook.retry(should_retry_eni)
def _start_wait_task(self, context):
if not self.arn:
# start the task except if we reattached to an existing one just before.
self._start_task(context)

if self.deferrable:
@@ -598,6 +560,7 @@ def _start_wait_task(self, context):
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
# self.defer raises a special exception, so execution stops here in this case.

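# when not waiting for completion, return before any status check: the task cannot be in a terminal state right after creation (this is the bug described in the commit message)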
if not self.wait_for_completion:
return
@@ -615,9 +578,45 @@ def _start_wait_task(self, context):
else:
self._wait_for_task_ended()

self._after_execution(session)

if self.do_xcom_push and self.task_log_fetcher:
return self.task_log_fetcher.get_last_log_message()
else:
return None

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error in task execution: {event}")
self.arn = event["task_arn"] # restore arn to its updated value, needed for next steps
self._after_execution()
if self._aws_logs_enabled():
# same behavior as non-deferrable mode, return last line of logs of the task.
logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).conn
one_log = logs_client.get_log_events(
logGroupName=self.awslogs_group,
logStreamName=self._get_logs_stream_name(),
startFromHead=False,
limit=1,
)
if len(one_log["events"]) > 0:
return one_log["events"][0]["message"]

@provide_session
def _after_execution(self, session=None):
self._check_success_task()

self.log.info("ECS Task has been successfully executed")

if self.reattach:
# Clear the XCom value storing the ECS task ARN if the task has completed
# as we can't reattach it anymore
self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))

def _xcom_del(self, session, task_id):
session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == task_id).delete()

@AwsBaseHook.retry(should_retry_eni)
def _start_task(self, context):
run_opts = {
"cluster": self.cluster,
