Skip to content

Commit

Permalink
Deferrable mode for ECS operators (apache#31881)
Browse files Browse the repository at this point in the history
  • Loading branch information
vandonr-amz authored Jun 23, 2023
1 parent e4ca688 commit 415e076
Show file tree
Hide file tree
Showing 7 changed files with 532 additions and 33 deletions.
160 changes: 133 additions & 27 deletions airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
EcsHook,
should_retry_eni,
)
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.triggers.ecs import (
ClusterWaiterTrigger,
TaskDoneTrigger,
)
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.utils.helpers import prune_dict
from airflow.utils.session import provide_session
Expand Down Expand Up @@ -67,6 +72,15 @@ def execute(self, context: Context):
"""Must overwrite in child classes."""
raise NotImplementedError("Please implement execute() in subclass")

def _complete_exec_with_cluster_desc(self, context, event=None):
    """
    Trigger callback for operators that return the cluster description.

    This is the method named as ``method_name`` when the cluster operators defer,
    so it runs once the trigger has yielded its event.

    :param context: The task execution context (not used here).
    :param event: Payload of the trigger event. It is expected to carry a ``status`` key
        and, on success, the cluster ARN under the ``arn`` key.
    :return: The first entry of the ``clusters`` list returned by ``describe_clusters``.
    :raises AirflowException: If no event was received, or if its status is not ``success``.
    """
    # ``event`` defaults to None, so it must be checked before being subscripted.
    # A payload without a ``status`` key is treated as a failure too, instead of
    # surfacing as an unrelated TypeError/KeyError.
    if not event or event.get("status") != "success":
        raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
    cluster_arn = event.get("arn")
    # We cannot get the cluster definition from the waiter on success, so we have to query it here.
    details = self.hook.conn.describe_clusters(clusters=[cluster_arn])["clusters"][0]
    return details


class EcsCreateClusterOperator(EcsBaseOperator):
"""
Expand All @@ -84,18 +98,27 @@ class EcsCreateClusterOperator(EcsBaseOperator):
if not set then the default waiter value will be used.
:param waiter_max_attempts: The maximum number of attempts to be made. (default: 60)
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
"""

template_fields: Sequence[str] = ("cluster_name", "create_cluster_kwargs", "wait_for_completion")
template_fields: Sequence[str] = (
"cluster_name",
"create_cluster_kwargs",
"wait_for_completion",
"deferrable",
)

def __init__(
self,
*,
cluster_name: str,
create_cluster_kwargs: dict | None = None,
wait_for_completion: bool = True,
waiter_delay: int | None = None,
waiter_max_attempts: int | None = None,
waiter_delay: int = 15,
waiter_max_attempts: int = 60,
deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -104,6 +127,7 @@ def __init__(
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

def execute(self, context: Context):
self.log.info(
Expand All @@ -119,6 +143,21 @@ def execute(self, context: Context):
# In some circumstances the ECS Cluster is created immediately,
# and there is no reason to wait for completion.
self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state)
elif self.deferrable:
self.defer(
trigger=ClusterWaiterTrigger(
waiter_name="cluster_active",
cluster_arn=cluster_details["clusterArn"],
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region=self.region,
),
method_name="_complete_exec_with_cluster_desc",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
elif self.wait_for_completion:
waiter = self.hook.get_waiter("cluster_active")
waiter.wait(
Expand Down Expand Up @@ -148,24 +187,29 @@ class EcsDeleteClusterOperator(EcsBaseOperator):
if not set then the default waiter value will be used.
:param waiter_max_attempts: The maximum number of attempts to be made. (default: 60)
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
"""

template_fields: Sequence[str] = ("cluster_name", "wait_for_completion")
template_fields: Sequence[str] = ("cluster_name", "wait_for_completion", "deferrable")

def __init__(
self,
*,
cluster_name: str,
wait_for_completion: bool = True,
waiter_delay: int | None = None,
waiter_max_attempts: int | None = None,
waiter_delay: int = 15,
waiter_max_attempts: int = 60,
deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.cluster_name = cluster_name
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

def execute(self, context: Context):
self.log.info("Deleting cluster %r.", self.cluster_name)
Expand All @@ -174,9 +218,24 @@ def execute(self, context: Context):
cluster_state = cluster_details.get("status")

if cluster_state == EcsClusterStates.INACTIVE:
# In some circumstances the ECS Cluster is deleted immediately,
# so there is no reason to wait for completion.
# if the cluster doesn't have capacity providers that are associated with it,
# the deletion is instantaneous, and we don't need to wait for it.
self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state)
elif self.deferrable:
self.defer(
trigger=ClusterWaiterTrigger(
waiter_name="cluster_inactive",
cluster_arn=cluster_details["clusterArn"],
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region=self.region,
),
method_name="_complete_exec_with_cluster_desc",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
elif self.wait_for_completion:
waiter = self.hook.get_waiter("cluster_inactive")
waiter.wait(
Expand Down Expand Up @@ -347,6 +406,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
finished.
:param awslogs_fetch_interval: the interval that the ECS task log fetcher should wait
in between each Cloudwatch logs fetches.
If deferrable is set to True, that parameter is ignored and waiter_delay is used instead.
:param quota_retry: Config if and how to retry the launch of a new ECS task, to handle
transient errors.
:param reattach: If set to True, will check if the task previously launched by the task_instance
Expand All @@ -361,6 +421,9 @@ class EcsRunTaskOperator(EcsBaseOperator):
if not set then the default waiter value will be used.
:param waiter_max_attempts: The maximum number of attempts to be made. (default: 100)
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
"""

ui_color = "#f0ede4"
Expand All @@ -384,6 +447,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
"reattach",
"number_logs_exception",
"wait_for_completion",
"deferrable",
)
template_fields_renderers = {
"overrides": "json",
Expand Down Expand Up @@ -416,8 +480,9 @@ def __init__(
reattach: bool = False,
number_logs_exception: int = 10,
wait_for_completion: bool = True,
waiter_delay: int | None = None,
waiter_max_attempts: int | None = None,
waiter_delay: int = 6,
waiter_max_attempts: int = 100,
deferrable: bool = False,
**kwargs,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -451,6 +516,7 @@ def __init__(
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

if self._aws_logs_enabled() and not self.wait_for_completion:
self.log.warning(
Expand All @@ -473,7 +539,35 @@ def execute(self, context, session=None):
if self.reattach:
self._try_reattach_task(context)

self._start_wait_check_task(context)
self._start_wait_task(context)

self._after_execution(session)

if self.do_xcom_push and self.task_log_fetcher:
return self.task_log_fetcher.get_last_log_message()
else:
return None

def execute_complete(self, context, event=None):
    """
    Resume the operator after the trigger watching the ECS task has fired.

    :param context: The task execution context (not used here).
    :param event: Payload of the trigger event. It is expected to carry a ``status`` key
        and, on success, the ARN of the task under ``task_arn``.
    :return: The message of the last log event of the task when CloudWatch logging is
        enabled and at least one event is available, otherwise None.
    :raises AirflowException: If no event was received, or if its status is not ``success``.
    """
    # ``event`` defaults to None, so it must be checked before being subscripted.
    # A payload without a ``status`` key is treated as a failure too.
    if not event or event.get("status") != "success":
        raise AirflowException(f"Error in task execution: {event}")
    self.arn = event["task_arn"]  # restore arn to its updated value, needed for next steps
    self._after_execution()
    if not self._aws_logs_enabled():
        return None
    # same behavior as non-deferrable mode, return last line of logs of the task.
    logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).conn
    one_log = logs_client.get_log_events(
        logGroupName=self.awslogs_group,
        logStreamName=self._get_logs_stream_name(),
        startFromHead=False,
        limit=1,
    )
    events = one_log["events"]
    return events[0]["message"] if events else None

@provide_session
def _after_execution(self, session=None):
self._check_success_task()

self.log.info("ECS Task has been successfully executed")

Expand All @@ -482,16 +576,29 @@ def execute(self, context, session=None):
# as we can't reattach it anymore
self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))

if self.do_xcom_push and self.task_log_fetcher:
return self.task_log_fetcher.get_last_log_message()

return None

@AwsBaseHook.retry(should_retry_eni)
def _start_wait_check_task(self, context):
def _start_wait_task(self, context):
if not self.arn:
self._start_task(context)

if self.deferrable:
self.defer(
trigger=TaskDoneTrigger(
cluster=self.cluster,
task_arn=self.arn,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region=self.region,
log_group=self.awslogs_group,
log_stream=self._get_logs_stream_name(),
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)

if not self.wait_for_completion:
return

Expand All @@ -508,8 +615,6 @@ def _start_wait_check_task(self, context):
else:
self._wait_for_task_ended()

self._check_success_task()

def _xcom_del(self, session, task_id):
    """Delete the XCom rows stored for this operator's DAG under the given ``task_id``."""
    same_dag = XCom.dag_id == self.dag_id
    same_task = XCom.task_id == task_id
    matching_rows = session.query(XCom).filter(same_dag, same_task)
    matching_rows.delete()

Expand Down Expand Up @@ -584,33 +689,34 @@ def _wait_for_task_ended(self) -> None:
waiter.wait(
cluster=self.cluster,
tasks=[self.arn],
WaiterConfig=prune_dict(
{
"Delay": self.waiter_delay,
"MaxAttempts": self.waiter_max_attempts,
}
),
WaiterConfig={
"Delay": self.waiter_delay,
"MaxAttempts": self.waiter_max_attempts,
},
)

return

def _aws_logs_enabled(self) -> bool:
    """
    Tell whether log fetching is configured for this task.

    :return: True only when both ``awslogs_group`` and ``awslogs_stream_prefix`` are truthy.
    """
    # bool() so that callers receive a real flag instead of the raw attribute value
    # (a string or None) that a bare ``and`` expression would hand back.
    return bool(self.awslogs_group and self.awslogs_stream_prefix)

def _get_logs_stream_name(self) -> str:
    """Build the log stream name as ``<awslogs_stream_prefix>/<ECS task id of self.arn>``."""
    stream_prefix = self.awslogs_stream_prefix
    ecs_task_id = self._get_ecs_task_id(self.arn)
    return f"{stream_prefix}/{ecs_task_id}"

def _get_task_log_fetcher(self) -> AwsTaskLogFetcher:
    """
    Create the log fetcher for the ECS task.

    The stream name comes from ``_get_logs_stream_name`` so that it is built in a single
    place; no local copy of it is kept and the keyword is passed exactly once.

    :return: An ``AwsTaskLogFetcher`` configured from the operator's ``awslogs_*`` settings.
    :raises ValueError: If ``awslogs_group`` is not set.
    """
    if not self.awslogs_group:
        raise ValueError("must specify awslogs_group to fetch task logs")

    return AwsTaskLogFetcher(
        aws_conn_id=self.aws_conn_id,
        region_name=self.awslogs_region,
        log_group=self.awslogs_group,
        log_stream_name=self._get_logs_stream_name(),
        fetch_interval=self.awslogs_fetch_interval,
        logger=self.log,
    )

@AwsBaseHook.retry(should_retry_eni)
def _check_success_task(self) -> None:
if not self.client or not self.arn:
return
Expand Down
Loading

0 comments on commit 415e076

Please sign in to comment.