From b1bc1d632b47b09ce269ac26b49e8ab084b523ef Mon Sep 17 00:00:00 2001 From: adam Date: Fri, 18 Oct 2024 04:21:38 +0000 Subject: [PATCH] Add container_name and update awslogs_stream_prefix pattern --- providers/src/airflow/providers/amazon/aws/operators/ecs.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/providers/src/airflow/providers/amazon/aws/operators/ecs.py b/providers/src/airflow/providers/amazon/aws/operators/ecs.py index 6f2906f5ad6e2..51dde9f75a30d 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/ecs.py +++ b/providers/src/airflow/providers/amazon/aws/operators/ecs.py @@ -368,7 +368,7 @@ class EcsRunTaskOperator(EcsBaseOperator): If None, this is the same as the `region` parameter. If that is also None, this is the default AWS region based on your connection settings. :param awslogs_stream_prefix: the stream prefix that is used for the CloudWatch logs. - This is usually based on some custom name combined with the name of the container. + This should match the prefix specified in the log configuration of the task definition. Only required if you want logs to be shown in the Airflow UI after your job has finished. :param awslogs_fetch_interval: the interval that the ECS task log fetcher should wait @@ -481,6 +481,7 @@ def __init__( self.awslogs_region = self.region_name self.arn: str | None = None + self.container_name: str | None = None self._started_by: str | None = None self.retry_args = quota_retry @@ -624,6 +625,7 @@ def _start_task(self): self.log.info("ECS Task started: %s", response) self.arn = response["tasks"][0]["taskArn"] + self.container_name = response["tasks"][0]["containers"][0]["name"] self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn)) def _try_reattach_task(self, started_by: str): @@ -659,7 +661,7 @@ def _aws_logs_enabled(self): return self.awslogs_group and self.awslogs_stream_prefix def _get_logs_stream_name(self) -> str: - return f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}" + return f"{self.awslogs_stream_prefix}/{self.container_name}/{self._get_ecs_task_id(self.arn)}" def _get_task_log_fetcher(self) -> AwsTaskLogFetcher: if not self.awslogs_group: