Deferrable mode for ECS operators #31881
@@ -35,6 +35,11 @@
     EcsHook,
     should_retry_eni,
 )
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.triggers.ecs import (
+    ClusterWaiterTrigger,
+    TaskDoneTrigger,
+)
 from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
 from airflow.utils.helpers import prune_dict
 from airflow.utils.session import provide_session
@@ -67,6 +72,15 @@ def execute(self, context: Context):
         """Must overwrite in child classes."""
         raise NotImplementedError("Please implement execute() in subclass")
 
+    def _complete_exec_with_cluster_desc(self, context, event=None):
+        """To be used as trigger callback for operators that return the cluster description."""
+        if event["status"] != "success":
+            raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
+        cluster_arn = event.get("arn")
+        # We cannot get the cluster definition from the waiter on success, so we have to query it here.
+        details = self.hook.conn.describe_clusters(clusters=[cluster_arn])["clusters"][0]
+        return details
+
 
 class EcsCreateClusterOperator(EcsBaseOperator):
     """
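For orientation on what the `event` argument of `_complete_exec_with_cluster_desc` contains: an Airflow trigger finishes by yielding a TriggerEvent, and that event's payload is handed to the method named in `method_name` when the task resumes. The snippet below is not the ClusterWaiterTrigger added by this PR (that lives in the triggers module imported above); it is a minimal sketch of the contract, with a hypothetical class path and a payload shaped after the keys the callback reads ("status" and "arn").

# Minimal sketch of the trigger/callback contract (hypothetical trigger, not the PR's ClusterWaiterTrigger).
from airflow.triggers.base import BaseTrigger, TriggerEvent


class FakeClusterTrigger(BaseTrigger):
    """Immediately reports success for a given cluster ARN."""

    def __init__(self, cluster_arn: str):
        super().__init__()
        self.cluster_arn = cluster_arn

    def serialize(self):
        # Triggers must serialize to (classpath, kwargs) so the triggerer process can re-create them.
        return ("my_dags.triggers.FakeClusterTrigger", {"cluster_arn": self.cluster_arn})

    async def run(self):
        # Whatever is yielded here becomes the `event` dict passed to the resume method,
        # e.g. _complete_exec_with_cluster_desc above.
        yield TriggerEvent({"status": "success", "arn": self.cluster_arn})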
@@ -84,18 +98,27 @@ class EcsCreateClusterOperator(EcsBaseOperator):
         if not set then the default waiter value will be used.
     :param waiter_max_attempts: The maximum number of attempts to be made,
         if not set then the default waiter value will be used.
+    :param deferrable: If True, the operator will wait asynchronously for the job to complete.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
     """
 
-    template_fields: Sequence[str] = ("cluster_name", "create_cluster_kwargs", "wait_for_completion")
+    template_fields: Sequence[str] = (
+        "cluster_name",
+        "create_cluster_kwargs",
+        "wait_for_completion",
+        "deferrable",
+    )
 
     def __init__(
         self,
         *,
         cluster_name: str,
         create_cluster_kwargs: dict | None = None,
         wait_for_completion: bool = True,
-        waiter_delay: int | None = None,
-        waiter_max_attempts: int | None = None,
+        waiter_delay: int = 15,
+        waiter_max_attempts: int = 60,
+        deferrable: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -104,6 +127,7 @@ def __init__(
         self.wait_for_completion = wait_for_completion
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context):
         self.log.info(
@@ -119,6 +143,21 @@ def execute(self, context: Context):
             # In some circumstances the ECS Cluster is created immediately,
             # and there is no reason to wait for completion.
             self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state)
+        elif self.deferrable:
+            self.defer(
+                trigger=ClusterWaiterTrigger(
+                    waiter_name="cluster_active",
+                    cluster_arn=cluster_details["clusterArn"],
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region=self.region,
+                ),
+                method_name="_complete_exec_with_cluster_desc",
+                # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
+            )
         elif self.wait_for_completion:
             waiter = self.hook.get_waiter("cluster_active")
             waiter.wait(
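Not part of the diff, but as a usage sketch of the flag added above (assuming the Amazon provider from this PR is installed): a DAG author opting into deferral only needs to set `deferrable=True`. The DAG id and cluster name below are hypothetical.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.ecs import EcsCreateClusterOperator

with DAG(dag_id="example_ecs_deferrable", start_date=datetime(2023, 1, 1), schedule=None):
    create_cluster = EcsCreateClusterOperator(
        task_id="create_cluster",
        cluster_name="my-deferrable-cluster",  # hypothetical cluster name
        deferrable=True,          # wait on the triggerer instead of holding a worker slot
        waiter_delay=15,          # the new defaults: poll every 15s, at most 60 times
        waiter_max_attempts=60,
    )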
@@ -148,24 +187,29 @@ class EcsDeleteClusterOperator(EcsBaseOperator):
         if not set then the default waiter value will be used.
     :param waiter_max_attempts: The maximum number of attempts to be made,
         if not set then the default waiter value will be used.
+    :param deferrable: If True, the operator will wait asynchronously for the job to complete.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
     """
 
-    template_fields: Sequence[str] = ("cluster_name", "wait_for_completion")
+    template_fields: Sequence[str] = ("cluster_name", "wait_for_completion", "deferrable")
 
     def __init__(
         self,
         *,
         cluster_name: str,
         wait_for_completion: bool = True,
-        waiter_delay: int | None = None,
-        waiter_max_attempts: int | None = None,
+        waiter_delay: int = 15,
+        waiter_max_attempts: int = 60,
+        deferrable: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         self.cluster_name = cluster_name
         self.wait_for_completion = wait_for_completion
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context):
         self.log.info("Deleting cluster %r.", self.cluster_name)
@@ -174,9 +218,24 @@ def execute(self, context: Context):
         cluster_state = cluster_details.get("status")
 
         if cluster_state == EcsClusterStates.INACTIVE:
-            # In some circumstances the ECS Cluster is deleted immediately,
-            # so there is no reason to wait for completion.
+            # if the cluster doesn't have capacity providers that are associated with it,
+            # the deletion is instantaneous, and we don't need to wait for it.
             self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state)
+        elif self.deferrable:
+            self.defer(
+                trigger=ClusterWaiterTrigger(
+                    waiter_name="cluster_inactive",
+                    cluster_arn=cluster_details["clusterArn"],
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region=self.region,
+                ),
+                method_name="_complete_exec_with_cluster_desc",
+                # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
+            )
         elif self.wait_for_completion:
            waiter = self.hook.get_waiter("cluster_inactive")
             waiter.wait(

Review thread on the changed comment above: "hmm, we could do that, but the check on the status above is already taking care of that. We can write a different check, but the result would be the same."
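To make the repeated timeout comment concrete (illustrative arithmetic only, not code from the diff): with the cluster operators' new defaults, the deferral timeout passed to self.defer works out to 16 minutes.

from datetime import timedelta

waiter_delay, waiter_max_attempts = 15, 60  # defaults introduced above
# Total polling budget plus a 60-second grace period for the trigger to yield its TriggerEvent.
timeout = timedelta(seconds=waiter_max_attempts * waiter_delay + 60)
print(timeout)  # 0:16:00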
@@ -347,6 +406,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         finished.
     :param awslogs_fetch_interval: the interval that the ECS task log fetcher should wait
         in between each Cloudwatch logs fetches.
+        If deferrable is set to True, that parameter is ignored and waiter_delay is used instead.
     :param quota_retry: Config if and how to retry the launch of a new ECS task, to handle
         transient errors.
     :param reattach: If set to True, will check if the task previously launched by the task_instance
@@ -361,6 +421,9 @@
         if not set then the default waiter value will be used.
     :param waiter_max_attempts: The maximum number of attempts to be made,
         if not set then the default waiter value will be used.
+    :param deferrable: If True, the operator will wait asynchronously for the job to complete.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
     """
 
     ui_color = "#f0ede4"
@@ -384,6 +447,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         "reattach",
         "number_logs_exception",
         "wait_for_completion",
+        "deferrable",
     )
     template_fields_renderers = {
         "overrides": "json",
@@ -416,8 +480,9 @@ def __init__(
         reattach: bool = False,
         number_logs_exception: int = 10,
         wait_for_completion: bool = True,
-        waiter_delay: int | None = None,
-        waiter_max_attempts: int | None = None,
+        waiter_delay: int = 6,
+        waiter_max_attempts: int = 100,
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
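Again as a usage sketch rather than part of the diff: for the run-task operator, deferral hands the waiting (and, per the docstring change above, log polling at waiter_delay intervals) to the triggerer. All resource names below are placeholders.

from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator

run_task = EcsRunTaskOperator(
    task_id="run_ecs_task",
    cluster="existing-cluster",                 # placeholder
    task_definition="my-task-def:1",            # placeholder
    launch_type="FARGATE",
    overrides={"containerOverrides": []},
    awslogs_group="/ecs/my-task",               # placeholder; enables returning the last log line
    awslogs_stream_prefix="ecs/my-container",   # placeholder
    deferrable=True,
    waiter_delay=6,           # new defaults: check task state every 6s, up to 100 times
    waiter_max_attempts=100,
)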
@@ -451,6 +516,7 @@ def __init__(
         self.wait_for_completion = wait_for_completion
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
         if self._aws_logs_enabled() and not self.wait_for_completion:
             self.log.warning(
@@ -473,7 +539,35 @@ def execute(self, context, session=None):
         if self.reattach:
             self._try_reattach_task(context)
 
-        self._start_wait_check_task(context)
+        self._start_wait_task(context)
 
+        self._after_execution(session)
+
+        if self.do_xcom_push and self.task_log_fetcher:
+            return self.task_log_fetcher.get_last_log_message()
+        else:
+            return None
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error in task execution: {event}")
+        self.arn = event["task_arn"]  # restore arn to its updated value, needed for next steps
+        self._after_execution()
+        if self._aws_logs_enabled():
+            # same behavior as non-deferrable mode, return last line of logs of the task.
+            logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).conn
+            one_log = logs_client.get_log_events(
+                logGroupName=self.awslogs_group,
+                logStreamName=self._get_logs_stream_name(),
+                startFromHead=False,
+                limit=1,
+            )
+            if len(one_log["events"]) > 0:
+                return one_log["events"][0]["message"]
+
+    @provide_session
+    def _after_execution(self, session=None):
+        self._check_success_task()
+
         self.log.info("ECS Task has been successfully executed")
 

Review note on extracting _after_execution: "I wanted to extract this to reuse it in …"
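The `event` that execute_complete receives is produced by TaskDoneTrigger (added elsewhere in this PR); the exact payload is not shown in this file, but the keys read above imply a shape like the following, with a placeholder ARN.

# Illustrative resume payload, inferred from the keys execute_complete reads.
event = {
    "status": "success",  # anything else raises AirflowException
    "task_arn": "arn:aws:ecs:us-east-1:123456789012:task/some-cluster/0123456789abcdef",  # placeholder
}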
@@ -482,16 +576,29 @@ def execute(self, context, session=None):
             # as we can't reattach it anymore
             self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))
 
-        if self.do_xcom_push and self.task_log_fetcher:
-            return self.task_log_fetcher.get_last_log_message()
-
-        return None
-
     @AwsBaseHook.retry(should_retry_eni)
-    def _start_wait_check_task(self, context):
+    def _start_wait_task(self, context):
         if not self.arn:
             self._start_task(context)
 
+        if self.deferrable:
+            self.defer(
+                trigger=TaskDoneTrigger(
+                    cluster=self.cluster,
+                    task_arn=self.arn,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region=self.region,
+                    log_group=self.awslogs_group,
+                    log_stream=self._get_logs_stream_name(),
+                ),
+                method_name="execute_complete",
+                # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
+            )
+
         if not self.wait_for_completion:
             return
 

Review comment on lines -491 to +580: "the check went to …"
@@ -508,8 +615,6 @@ def _start_wait_check_task(self, context):
         else:
             self._wait_for_task_ended()
 
-        self._check_success_task()
-
     def _xcom_del(self, session, task_id):
         session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == task_id).delete()
 
@@ -584,33 +689,34 @@ def _wait_for_task_ended(self) -> None:
         waiter.wait(
             cluster=self.cluster,
             tasks=[self.arn],
-            WaiterConfig=prune_dict(
-                {
-                    "Delay": self.waiter_delay,
-                    "MaxAttempts": self.waiter_max_attempts,
-                }
-            ),
+            WaiterConfig={
+                "Delay": self.waiter_delay,
+                "MaxAttempts": self.waiter_max_attempts,
+            },
         )
 
         return
 
     def _aws_logs_enabled(self):
         return self.awslogs_group and self.awslogs_stream_prefix
 
+    def _get_logs_stream_name(self) -> str:
+        return f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"
+
     def _get_task_log_fetcher(self) -> AwsTaskLogFetcher:
         if not self.awslogs_group:
             raise ValueError("must specify awslogs_group to fetch task logs")
-        log_stream_name = f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"
 
         return AwsTaskLogFetcher(
             aws_conn_id=self.aws_conn_id,
             region_name=self.awslogs_region,
             log_group=self.awslogs_group,
-            log_stream_name=log_stream_name,
+            log_stream_name=self._get_logs_stream_name(),
             fetch_interval=self.awslogs_fetch_interval,
             logger=self.log,
         )
 
     @AwsBaseHook.retry(should_retry_eni)
     def _check_success_task(self) -> None:
         if not self.client or not self.arn:
             return
Review comment on _complete_exec_with_cluster_desc: "this callback is shared between create and delete cluster operators, so I put it there. It felt like a better solution than copy-pasting it for both."