Skip to content

Commit

Permalink
Skip served logs for non-running task try
Browse files Browse the repository at this point in the history
  • Loading branch information
Khrol committed Jul 12, 2023
1 parent fcbf159 commit 19be8f7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
12 changes: 6 additions & 6 deletions airflow/utils/log/file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ def _read(
executor_messages: list[str] = []
executor_logs: list[str] = []
served_logs: list[str] = []
is_running = ti.try_number == try_number and ti.state in (
TaskInstanceState.RUNNING,
TaskInstanceState.DEFERRED,
)
with suppress(NotImplementedError):
remote_messages, remote_logs = self._read_remote_logs(ti, try_number, metadata)
messages_list.extend(remote_messages)
Expand All @@ -321,7 +325,7 @@ def _read(
worker_log_full_path = Path(self.local_base, worker_log_rel_path)
local_messages, local_logs = self._read_from_local(worker_log_full_path)
messages_list.extend(local_messages)
if ti.state in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED) and not executor_messages:
if is_running and not executor_messages:
served_messages, served_logs = self._read_from_logs_server(ti, worker_log_rel_path)
messages_list.extend(served_messages)
elif ti.state not in State.unfinished and not (local_logs or remote_logs):
Expand All @@ -341,15 +345,11 @@ def _read(
)
log_pos = len(logs)
messages = "".join([f"*** {x}\n" for x in messages_list])
end_of_log = ti.try_number != try_number or ti.state not in (
TaskInstanceState.RUNNING,
TaskInstanceState.DEFERRED,
)
if metadata and "log_pos" in metadata:
previous_chars = metadata["log_pos"]
logs = logs[previous_chars:]  # Cut off previously passed log text, keeping only the new tail
out_message = logs if "log_pos" in (metadata or {}) else messages + logs
return out_message, {"end_of_log": end_of_log, "log_pos": log_pos}
return out_message, {"end_of_log": not is_running, "log_pos": log_pos}

@staticmethod
def _get_pod_namespace(ti: TaskInstance):
Expand Down
16 changes: 13 additions & 3 deletions tests/utils/test_log_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def test__read_for_k8s_executor(self, mock_k8s_get_task_log, create_task_instanc

def test__read_for_celery_executor_fallbacks_to_worker(self, create_task_instance):
"""Test that for executors which do not have a `get_task_log` method, it falls back to reading
log from worker"""
log from worker. But it happens only for the latest try_number."""
executor_name = "CeleryExecutor"

ti = create_task_instance(
Expand All @@ -306,14 +306,24 @@ def test__read_for_celery_executor_fallbacks_to_worker(self, create_task_instanc
execution_date=DEFAULT_DATE,
)
ti.state = TaskInstanceState.RUNNING
ti.try_number = 2
with conf_vars({("core", "executor"): executor_name}):
fth = FileTaskHandler("")

fth._read_from_logs_server = mock.Mock()
fth._read_from_logs_server.return_value = ["this message"], ["this\nlog\ncontent"]
actual = fth._read(ti=ti, try_number=1)
actual = fth._read(ti=ti, try_number=2)
fth._read_from_logs_server.assert_called_once()
assert actual == ("*** this message\nthis\nlog\ncontent", {"end_of_log": True, "log_pos": 16})
assert actual == ("*** this message\nthis\nlog\ncontent", {"end_of_log": False, "log_pos": 16})

# For a previous try_number, logs are read from remote logs without reaching the worker log server
fth._read_from_logs_server.reset_mock()
fth._read_remote_logs = mock.Mock()
fth._read_remote_logs.return_value = ["remote logs"], ["remote\nlog\ncontent"]
actual = fth._read(ti=ti, try_number=1)
fth._read_remote_logs.assert_called_once()
fth._read_from_logs_server.assert_not_called()
assert actual == ("*** remote logs\nremote\nlog\ncontent", {"end_of_log": True, "log_pos": 18})

@pytest.mark.parametrize(
"remote_logs, local_logs, served_logs_checked",
Expand Down

0 comments on commit 19be8f7

Please sign in to comment.