Improve fetching logs from AWS (#33231)
Use the start and end time inputs to the CloudWatch API to reduce the log search space and speed up log retrieval.

GitOrigin-RevId: c14cb85f16b6c9befd35866327fecb4ab9bc0fc4
vincbeck authored and Cloud Composer Team committed Nov 8, 2024
1 parent cc8729b commit ff2ec3f
Showing 5 changed files with 99 additions and 18 deletions.
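
For orientation, here is a hedged sketch of the underlying CloudWatch Logs GetLogEvents call that this change narrows; the group/stream names and timestamps are hypothetical, but the parameter names match the boto3 API:

import boto3

client = boto3.client("logs", region_name="us-east-1")  # hypothetical region
response = client.get_log_events(
    logGroupName="example-group",        # hypothetical log group
    logStreamName="example-log-stream",  # hypothetical stream
    startTime=1_577_836_800_000,  # epoch ms; lower bound of the search window
    endTime=1_577_837_130_000,    # epoch ms; upper bound (omit to read to the stream's end)
    startFromHead=True,
)
for event in response["events"]:
    print(event["timestamp"], event["message"])

Bounding the window with both timestamps is what lets CloudWatch skip irrelevant parts of the stream instead of scanning it from the beginning.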
21 changes: 15 additions & 6 deletions airflow/providers/amazon/aws/hooks/logs.py
@@ -22,6 +22,7 @@

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.helpers import prune_dict

# Guidance received from the AWS team regarding the correct way to check for the end of a stream is that the
# value of the nextForwardToken is the same in subsequent calls.
@@ -60,6 +61,7 @@ def get_log_events(
log_group: str,
log_stream_name: str,
start_time: int = 0,
end_time: int | None = None,
skip: int = 0,
start_from_head: bool | None = None,
continuation_token: ContinuationToken | None = None,
@@ -72,7 +74,9 @@
:param log_group: The name of the log group.
:param log_stream_name: The name of the specific stream.
:param start_time: The time stamp value to start reading the logs from (default: 0).
:param start_time: The timestamp value in ms to start reading the logs from (default: 0).
:param end_time: The timestamp value in ms to stop reading the logs at (default: None).
If None is provided, read until the end of the log stream.
:param skip: The number of log entries to skip at the start (default: 0).
This is for when there are multiple entries at the same timestamp.
:param start_from_head: Deprecated. Do not set this to False; logs would be retrieved out of order.
@@ -110,11 +114,16 @@ def get_log_events(
token_arg = {}

response = self.conn.get_log_events(
logGroupName=log_group,
logStreamName=log_stream_name,
startTime=start_time,
startFromHead=start_from_head,
**token_arg,
**prune_dict(
{
"logGroupName": log_group,
"logStreamName": log_stream_name,
"startTime": start_time,
"endTime": end_time,
"startFromHead": start_from_head,
**token_arg,
}
)
)

events = response["events"]
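For reference, a minimal sketch of the prune_dict behavior the new call relies on, assuming the helper simply drops None values (as airflow.utils.helpers.prune_dict does in its default mode):

def prune_dict_sketch(d: dict) -> dict:
    # Hypothetical stand-in for airflow.utils.helpers.prune_dict:
    # drop None-valued keys so optional kwargs such as endTime are
    # only sent to the API when they were actually provided.
    return {k: v for k, v in d.items() if v is not None}

params = prune_dict_sketch(
    {
        "logGroupName": "example-group",       # hypothetical names
        "logStreamName": "example-log-stream",
        "startTime": 0,
        "endTime": None,  # dropped, so CloudWatch reads to the end of the stream
        "startFromHead": True,
    }
)
assert "endTime" not in params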
23 changes: 19 additions & 4 deletions airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -17,13 +17,15 @@
# under the License.
from __future__ import annotations

from datetime import datetime
from datetime import datetime, timedelta
from functools import cached_property

import watchtower

from airflow.configuration import conf
from airflow.models import TaskInstance
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
from airflow.utils.log.file_task_handler import FileTaskHandler
from airflow.utils.log.logging_mixin import LoggingMixin

@@ -90,7 +92,8 @@ def _read(self, task_instance, try_number, metadata=None):
try:
return (
f"*** Reading remote log from Cloudwatch log_group: {self.log_group} "
f"log_stream: {stream_name}.\n{self.get_cloudwatch_logs(stream_name=stream_name)}\n",
f"log_stream: {stream_name}.\n"
f"{self.get_cloudwatch_logs(stream_name=stream_name, task_instance=task_instance)}\n",
{"end_of_log": True},
)
except Exception as e:
@@ -103,17 +106,29 @@ def _read(self, task_instance, try_number, metadata=None):
log += local_log
return log, metadata

def get_cloudwatch_logs(self, stream_name: str) -> str:
def get_cloudwatch_logs(self, stream_name: str, task_instance: TaskInstance) -> str:
"""
Return all logs from the given log stream.

:param stream_name: name of the Cloudwatch log stream to get all logs from
:param task_instance: the task instance to get logs about
:return: string of all logs from the given log stream
"""
start_time = (
0 if task_instance.start_date is None else datetime_to_epoch_utc_ms(task_instance.start_date)
)
# If the task instance has an end_date, fetch logs until that date + 30 seconds
# 30 seconds is an arbitrary buffer so that we don't miss any logs emitted shortly
# after the task finished
end_time = (
None
if task_instance.end_date is None
else datetime_to_epoch_utc_ms(task_instance.end_date + timedelta(seconds=30))
)
events = self.hook.get_log_events(
log_group=self.log_group,
log_stream_name=stream_name,
start_from_head=True,
start_time=start_time,
end_time=end_time,
)
return "\n".join(self._event_to_str(event) for event in events)

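A small, self-contained sketch of the time-window computation above; the dates are hypothetical, and datetime_to_epoch_utc_ms is reproduced from the provider utils so the snippet runs on its own:

from datetime import datetime, timedelta, timezone

def datetime_to_epoch_utc_ms(date_time: datetime) -> int:
    # Same conversion as the provider helper: pin the naive datetime to UTC
    return int(date_time.replace(tzinfo=timezone.utc).timestamp() * 1_000)

start_date = datetime(2020, 1, 1)      # hypothetical task start
end_date = datetime(2020, 1, 1, 0, 5)  # hypothetical task end

start_time = datetime_to_epoch_utc_ms(start_date)
# 30-second buffer so events written just after the task finishes are kept
end_time = datetime_to_epoch_utc_ms(end_date + timedelta(seconds=30))

assert end_time - start_time == 330_000  # 5 min 30 s, in milliseconds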
7 changes: 6 additions & 1 deletion airflow/providers/amazon/aws/utils/__init__.py
@@ -18,7 +18,7 @@

import logging
import re
from datetime import datetime
from datetime import datetime, timezone
from enum import Enum

from airflow.utils.helpers import prune_dict
@@ -55,6 +55,11 @@ def datetime_to_epoch_ms(date_time: datetime) -> int:
return int(date_time.timestamp() * 1_000)


def datetime_to_epoch_utc_ms(date_time: datetime) -> int:
"""Convert a datetime object to an epoch integer (milliseconds) in UTC timezone."""
return int(date_time.replace(tzinfo=timezone.utc).timestamp() * 1_000)


def datetime_to_epoch_us(date_time: datetime) -> int:
"""Convert a datetime object to an epoch integer (microseconds)."""
return int(date_time.timestamp() * 1_000_000)
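A quick check of the new helper against the existing datetime_to_epoch_ms (a sketch; the naive-datetime variant's output depends on the machine's timezone):

from datetime import datetime, timezone

def datetime_to_epoch_ms(date_time: datetime) -> int:
    return int(date_time.timestamp() * 1_000)

def datetime_to_epoch_utc_ms(date_time: datetime) -> int:
    return int(date_time.replace(tzinfo=timezone.utc).timestamp() * 1_000)

naive = datetime(2020, 1, 1)
# The UTC variant pins naive datetimes to UTC, so its result is deterministic:
assert datetime_to_epoch_utc_ms(naive) == 1_577_836_800_000
# datetime_to_epoch_ms interprets naive datetimes in the local timezone,
# so its value shifts with the machine's UTC offset.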
29 changes: 24 additions & 5 deletions tests/providers/amazon/aws/hooks/test_logs.py
@@ -18,7 +18,7 @@
from __future__ import annotations

from unittest import mock
from unittest.mock import patch
from unittest.mock import ANY, patch

import pytest
from moto import mock_logs
@@ -29,7 +29,7 @@
@mock_logs
class TestAwsLogsHook:
@pytest.mark.parametrize(
"get_log_events_response, num_skip_events, expected_num_events",
"get_log_events_response, num_skip_events, expected_num_events, end_time",
[
# 3 empty responses with different tokens
(
@@ -40,6 +40,7 @@
],
0,
0,
None,
),
# 2 events on the second response with same token
(
@@ -49,6 +50,7 @@
],
0,
2,
None,
),
# Different tokens, 2 events on the second response then 3 empty responses
(
@@ -63,6 +65,7 @@
],
0,
2,
10,
),
# 2 events on the second response, then 2 empty responses, then 2 consecutive responses with
# 2 events with the same token
@@ -79,20 +82,36 @@
],
0,
6,
20,
),
],
)
@patch("airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.conn", new_callable=mock.PropertyMock)
def test_get_log_events(self, mock_conn, get_log_events_response, num_skip_events, expected_num_events):
def test_get_log_events(
self, mock_conn, get_log_events_response, num_skip_events, expected_num_events, end_time
):
mock_conn().get_log_events.side_effect = get_log_events_response

log_group_name = "example-group"
log_stream_name = "example-log-stream"
hook = AwsLogsHook(aws_conn_id="aws_default", region_name="us-east-1")
events = hook.get_log_events(
log_group="example-group",
log_stream_name="example-log-stream",
log_group=log_group_name,
log_stream_name=log_stream_name,
skip=num_skip_events,
end_time=end_time,
)

events = list(events)

assert len(events) == expected_num_events
kwargs = {
"logGroupName": log_group_name,
"logStreamName": log_stream_name,
"startFromHead": True,
"startTime": 0,
"nextToken": ANY,
}
if end_time:
kwargs["endTime"] = end_time
mock_conn().get_log_events.assert_called_with(**kwargs)
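
A minimal illustration of the mock.ANY sentinel used for nextToken above; it compares equal to any value, so the assertion tolerates whichever continuation token the hook passed (the mock below is purely illustrative):

from unittest import mock

m = mock.Mock()
m.fetch(page=1, token="opaque-token-xyz")
m.fetch.assert_called_with(page=1, token=mock.ANY)  # passes for any token value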
37 changes: 35 additions & 2 deletions tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
@@ -18,7 +18,7 @@
from __future__ import annotations

import time
from datetime import datetime as dt
from datetime import datetime as dt, timedelta
from unittest import mock
from unittest.mock import call

@@ -31,6 +31,7 @@
from airflow.operators.empty import EmptyOperator
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.log.cloudwatch_task_handler import CloudwatchTaskHandler
from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timezone import datetime
@@ -60,7 +61,6 @@ def setup_tests(self, create_log_template, tmp_path_factory):
self.local_log_location,
f"arn:aws:logs:{self.region_name}:11111111:log-group:{self.remote_log_group}",
)
self.cloudwatch_task_handler.hook

date = datetime(2020, 1, 1)
dag_id = "dag_for_testing_cloudwatch_task_handler"
@@ -154,6 +154,39 @@ def test_read(self):
[{"end_of_log": True}],
)

@pytest.mark.parametrize(
"start_date, end_date, expected_start_time, expected_end_time",
[
(None, None, 0, None),
(datetime(2020, 1, 1), None, datetime_to_epoch_utc_ms(datetime(2020, 1, 1)), None),
(
None,
datetime(2020, 1, 2),
0,
datetime_to_epoch_utc_ms(datetime(2020, 1, 2) + timedelta(seconds=30)),
),
(
datetime(2020, 1, 1),
datetime(2020, 1, 2),
datetime_to_epoch_utc_ms(datetime(2020, 1, 1)),
datetime_to_epoch_utc_ms(datetime(2020, 1, 2) + timedelta(seconds=30)),
),
],
)
@mock.patch.object(AwsLogsHook, "get_log_events")
def test_get_cloudwatch_logs(
self, mock_get_log_events, start_date, end_date, expected_start_time, expected_end_time
):
self.ti.start_date = start_date
self.ti.end_date = end_date
self.cloudwatch_task_handler.get_cloudwatch_logs(self.remote_log_stream, self.ti)
mock_get_log_events.assert_called_once_with(
log_group=self.remote_log_group,
log_stream_name=self.remote_log_stream,
start_time=expected_start_time,
end_time=expected_end_time,
)

def test_close_prevents_duplicate_calls(self):
with mock.patch("watchtower.CloudWatchLogHandler.close") as mock_log_handler_close:
with mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.set_context"):
