Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deferrable mode for ECS operators #31881

Merged
merged 26 commits into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
f4ea7d9
add deferrable mode for ECS Create Cluster
vandonr-amz May 25, 2023
4dddc9a
execute task - easy part, without the logs
vandonr-amz Jun 5, 2023
a6e5c0e
add logs support to run task operator
vandonr-amz Jun 6, 2023
bca71da
add tests
vandonr-amz Jun 9, 2023
59ff8f4
add deferrable for delete cluster, by adapting the create cluster tri…
vandonr-amz Jun 13, 2023
5dc163b
add deferrable parameter
vandonr-amz Jun 13, 2023
71b9648
rearranging code around a bit
vandonr-amz Jun 13, 2023
ab0f65d
tests
vandonr-amz Jun 13, 2023
a421457
add dots in comments
vandonr-amz Jun 13, 2023
b2d1dae
fix test
vandonr-amz Jun 13, 2023
913daeb
Merge remote-tracking branch 'origin/main' into vandonr/deferrable
vandonr-amz Jun 13, 2023
305ea5c
add trigger to yaml
vandonr-amz Jun 13, 2023
eab95e0
return last line of logs
vandonr-amz Jun 13, 2023
02ec64e
is this the right integration name ?
vandonr-amz Jun 13, 2023
960d1c5
add timeouts
vandonr-amz Jun 14, 2023
7141b0e
fix CI + some fix
vandonr-amz Jun 14, 2023
e215315
Merge remote-tracking branch 'origin/main' into vandonr/deferrable
vandonr-amz Jun 15, 2023
2eb2252
adjust expected value in test
vandonr-amz Jun 15, 2023
8861815
Merge remote-tracking branch 'origin/main' into vandonr/deferrable
vandonr-amz Jun 16, 2023
6ecd019
fix
vandonr-amz Jun 16, 2023
3038246
Merge remote-tracking branch 'origin/main' into vandonr/deferrable
vandonr-amz Jun 16, 2023
76a9f39
Merge remote-tracking branch 'origin/main' into vandonr/deferrable
vandonr-amz Jun 20, 2023
439bbd9
rename method in test
vandonr-amz Jun 20, 2023
5beb6b2
Merge remote-tracking branch 'origin/main' into vandonr/deferrable
vandonr-amz Jun 22, 2023
c5d50ad
use newly available wait method
vandonr-amz Jun 22, 2023
4e65d26
fix test
vandonr-amz Jun 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion airflow/providers/amazon/aws/hooks/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def _get_log_events(self, skip: int = 0) -> Generator:
self.logger.warning("ConnectionClosedError on retrieving Cloudwatch log events", error)
yield from ()

def _event_to_str(self, event: dict) -> str:
@staticmethod
def _event_to_str(event: dict) -> str:
event_dt = datetime.utcfromtimestamp(event["timestamp"] / 1000.0)
formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
message = event["message"]
Expand Down
118 changes: 101 additions & 17 deletions airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
EcsTaskLogFetcher,
should_retry_eni,
)
from airflow.providers.amazon.aws.triggers.ecs import (
ClusterWaiterTrigger,
TaskDoneTrigger,
)
from airflow.utils.helpers import prune_dict
from airflow.utils.session import provide_session

Expand Down Expand Up @@ -67,6 +71,15 @@ def execute(self, context: Context):
"""Must overwrite in child classes."""
raise NotImplementedError("Please implement execute() in subclass")

def _complete_exec_with_cluster_desc(self, context, event=None):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this callback is shared between create and delete cluster operators, so I put it there. It felt like a better solution than copy-pasting it for both.

"""To be used as trigger callback for operators that return the cluster description"""
if event["status"] != "success":
raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
cluster_arn = event.get("arn")
# We cannot get the cluster definition from the waiter on success, so we have to query it here.
details = self.hook.conn.describe_clusters(clusters=[cluster_arn])["clusters"][0]
return details


class EcsCreateClusterOperator(EcsBaseOperator):
"""
Expand All @@ -84,9 +97,17 @@ class EcsCreateClusterOperator(EcsBaseOperator):
if not set then the default waiter value will be used.
:param waiter_max_attempts: The maximum number of attempts to be made,
if not set then the default waiter value will be used.
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
"""

template_fields: Sequence[str] = ("cluster_name", "create_cluster_kwargs", "wait_for_completion")
template_fields: Sequence[str] = (
"cluster_name",
"create_cluster_kwargs",
"wait_for_completion",
"deferrable",
)

def __init__(
self,
Expand All @@ -96,6 +117,7 @@ def __init__(
wait_for_completion: bool = True,
waiter_delay: int | None = None,
waiter_max_attempts: int | None = None,
deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -104,6 +126,7 @@ def __init__(
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

def execute(self, context: Context):
self.log.info(
Expand All @@ -119,6 +142,18 @@ def execute(self, context: Context):
# In some circumstances the ECS Cluster is created immediately,
# and there is no reason to wait for completion.
self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state)
elif self.deferrable:
self.defer(
trigger=ClusterWaiterTrigger(
waiter_name="cluster_active",
cluster_arn=cluster_details["clusterArn"],
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region=self.region,
),
method_name="_complete_exec_with_cluster_desc",
)
vandonr-amz marked this conversation as resolved.
Show resolved Hide resolved
elif self.wait_for_completion:
waiter = self.hook.get_waiter("cluster_active")
waiter.wait(
Expand Down Expand Up @@ -148,9 +183,12 @@ class EcsDeleteClusterOperator(EcsBaseOperator):
if not set then the default waiter value will be used.
:param waiter_max_attempts: The maximum number of attempts to be made,
if not set then the default waiter value will be used.
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
"""

template_fields: Sequence[str] = ("cluster_name", "wait_for_completion")
template_fields: Sequence[str] = ("cluster_name", "wait_for_completion", "deferrable")

def __init__(
self,
Expand All @@ -159,13 +197,15 @@ def __init__(
wait_for_completion: bool = True,
waiter_delay: int | None = None,
waiter_max_attempts: int | None = None,
deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.cluster_name = cluster_name
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

def execute(self, context: Context):
self.log.info("Deleting cluster %r.", self.cluster_name)
Expand All @@ -174,9 +214,21 @@ def execute(self, context: Context):
cluster_state = cluster_details.get("status")

if cluster_state == EcsClusterStates.INACTIVE:
# In some circumstances the ECS Cluster is deleted immediately,
# so there is no reason to wait for completion.
# if the cluster doesn't have capacity providers that are associated with it,
# the deletion is instantaneous, and we don't need to wait for it.
Copy link
Contributor

@syedahsn syedahsn Jun 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cluster_details has the capacityProviders associated with the nodegroup. Would that be a better way to decide whether we want to wait for completion or not?
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs/client/delete_cluster.html

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, we could do that, but the check on the status above is already taking care of that. We can write a different check, but the result would be the same.

self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state)
elif self.deferrable:
self.defer(
trigger=ClusterWaiterTrigger(
waiter_name="cluster_inactive",
cluster_arn=cluster_details["clusterArn"],
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region=self.region,
),
method_name="_complete_exec_with_cluster_desc",
)
elif self.wait_for_completion:
waiter = self.hook.get_waiter("cluster_inactive")
waiter.wait(
Expand Down Expand Up @@ -375,6 +427,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
finished.
:param awslogs_fetch_interval: the interval that the ECS task log fetcher should wait
in between each Cloudwatch logs fetches.
If deferrable is set to True, that parameter is ignored and waiter_delay is used instead.
:param quota_retry: Config if and how to retry the launch of a new ECS task, to handle
transient errors.
:param reattach: If set to True, will check if the task previously launched by the task_instance
Expand All @@ -389,6 +442,9 @@ class EcsRunTaskOperator(EcsBaseOperator):
if not set then the default waiter value will be used.
:param waiter_max_attempts: The maximum number of attempts to be made,
if not set then the default waiter value will be used.
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
"""

ui_color = "#f0ede4"
Expand All @@ -412,6 +468,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
"reattach",
"number_logs_exception",
"wait_for_completion",
"deferrable",
)
template_fields_renderers = {
"overrides": "json",
Expand Down Expand Up @@ -446,6 +503,7 @@ def __init__(
wait_for_completion: bool = True,
waiter_delay: int | None = None,
waiter_max_attempts: int | None = None,
deferrable: bool = False,
**kwargs,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -479,6 +537,7 @@ def __init__(
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

@provide_session
def execute(self, context, session=None):
Expand All @@ -490,7 +549,26 @@ def execute(self, context, session=None):
if self.reattach:
self._try_reattach_task(context)

self._start_wait_check_task(context)
self._start_wait_task(context)

self._after_execution(session)

if self.do_xcom_push and self.task_log_fetcher:
return self.task_log_fetcher.get_last_log_message()
else:
return None

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error in task execution: {event}")
self.arn = event["task_arn"] # restore arn to its updated value
self._after_execution()
if self._aws_logs_enabled():
... # TODO return last log line but task_log_fetcher will always be None here

@provide_session
def _after_execution(self, session=None):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to extract this to reuse it in execute and execute_complete, but I wouldn't find a great name for it.

self._check_success_task()

self.log.info("ECS Task has been successfully executed")

Expand All @@ -499,18 +577,26 @@ def execute(self, context, session=None):
# as we can't reattach it anymore
self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))

if self.do_xcom_push and self.task_log_fetcher:
return self.task_log_fetcher.get_last_log_message()

return None

@AwsBaseHook.retry(should_retry_eni)
def _start_wait_check_task(self, context):
def _start_wait_task(self, context):
Comment on lines -491 to +580
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the check went to _after_execution


if not self.arn:
self._start_task(context)

if self._aws_logs_enabled():
if self.deferrable:
self.defer(
trigger=TaskDoneTrigger(
cluster=self.cluster,
task_arn=self.arn,
waiter_delay=self.waiter_delay,
aws_conn_id=self.aws_conn_id,
region=self.region,
log_group=self.awslogs_group,
log_stream=f"{self.awslogs_stream_prefix}/{self.ecs_task_id}",
),
method_name="execute_complete",
)
elif self._aws_logs_enabled():
self.log.info("Starting ECS Task Log Fetcher")
self.task_log_fetcher = self._get_task_log_fetcher()
self.task_log_fetcher.start()
Expand All @@ -522,11 +608,8 @@ def _start_wait_check_task(self, context):
self.task_log_fetcher.stop()

self.task_log_fetcher.join()
else:
if self.wait_for_completion:
self._wait_for_task_ended()

self._check_success_task()
elif self.wait_for_completion:
self._wait_for_task_ended()

def _xcom_del(self, session, task_id):
session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == task_id).delete()
Expand Down Expand Up @@ -631,6 +714,7 @@ def _get_task_log_fetcher(self) -> EcsTaskLogFetcher:
logger=self.log,
)

@AwsBaseHook.retry(should_retry_eni)
def _check_success_task(self) -> None:
if not self.client or not self.arn:
return
Expand Down
Loading