diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py index 13bf99d728d31..a8f68d724efb4 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes_tests/test_kubernetes_pod_operator.py @@ -1326,6 +1326,76 @@ class MyK8SPodOperator(KubernetesPodOperator): ) assert MyK8SPodOperator(task_id=str(uuid4())).base_container_name == "tomato-sauce" + def test_init_container_logs(self, mock_get_connection): + marker_from_init_container = f"{uuid4()}" + marker_from_main_container = f"{uuid4()}" + progress_callback = MagicMock() + init_container = k8s.V1Container( + name="init-container", + image="busybox", + command=["sh", "-cx"], + args=[f"echo {marker_from_init_container}"], + ) + k = KubernetesPodOperator( + namespace="default", + image="busybox", + cmds=["sh", "-cx"], + arguments=[f"echo {marker_from_main_container}"], + labels=self.labels, + task_id=str(uuid4()), + in_cluster=False, + do_xcom_push=False, + startup_timeout_seconds=5, + init_containers=[init_container], + init_container_logs=True, + progress_callback=progress_callback, + ) + context = create_context(k) + k.execute(context) + + calls = progress_callback.call_args_list + assert any(marker_from_init_container in "".join(c.args) for c in calls) + assert any(marker_from_main_container in "".join(c.args) for c in calls) + + def test_init_container_logs_filtered(self, mock_get_connection): + marker_from_init_container_to_log = f"{uuid4()}" + marker_from_init_container_to_ignore = f"{uuid4()}" + marker_from_main_container = f"{uuid4()}" + progress_callback = MagicMock() + init_container_to_log = k8s.V1Container( + name="init-container-to-log", + image="busybox", + command=["sh", "-cx"], + args=[f"echo {marker_from_init_container_to_log}"], + ) + init_container_to_ignore = k8s.V1Container( + name="init-container-to-ignore", + image="busybox", + command=["sh", "-cx"], + args=[f"echo {marker_from_init_container_to_ignore}"], + ) + k = KubernetesPodOperator( + namespace="default", + image="busybox", + cmds=["sh", "-cx"], + arguments=[f"echo {marker_from_main_container}"], + labels=self.labels, + task_id=str(uuid4()), + in_cluster=False, + do_xcom_push=False, + startup_timeout_seconds=5, + init_containers=[init_container_to_log, init_container_to_ignore], + init_container_logs=["init-container-to-log"], + progress_callback=progress_callback, + ) + context = create_context(k) + k.execute(context) + + calls = progress_callback.call_args_list + assert any(marker_from_init_container_to_log in "".join(c.args) for c in calls) + assert not any(marker_from_init_container_to_ignore in "".join(c.args) for c in calls) + assert any(marker_from_main_container in "".join(c.args) for c in calls) + def test_hide_sensitive_field_in_templated_fields_on_error(caplog, monkeypatch): logger = logging.getLogger("airflow.task") diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py index 62f08439d4160..9dadf8844707d 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -158,6 +158,9 @@ class KubernetesPodOperator(BaseOperator): :param startup_timeout_seconds: timeout in seconds to startup the pod. :param startup_check_interval_seconds: interval in seconds to check if the pod has already started :param get_logs: get the stdout of the base container as logs of the tasks. + :param init_container_logs: list of init containers whose logs will be published to stdout + Takes a sequence of containers, a single container name or True. If True, + all the containers logs are published. :param container_logs: list of containers whose logs will be published to stdout Takes a sequence of containers, a single container name or True. If True, all the containers logs are published. Works in conjunction with get_logs param. @@ -287,6 +290,7 @@ def __init__( startup_check_interval_seconds: int = 5, get_logs: bool = True, base_container_name: str | None = None, + init_container_logs: Iterable[str] | str | Literal[True] | None = None, container_logs: Iterable[str] | str | Literal[True] | None = None, image_pull_policy: str | None = None, annotations: dict | None = None, @@ -361,6 +365,7 @@ def __init__( # Fallback to the class variable BASE_CONTAINER_NAME here instead of via default argument value # in the init method signature, to be compatible with subclasses overloading the class variable value. self.base_container_name = base_container_name or self.BASE_CONTAINER_NAME + self.init_container_logs = init_container_logs self.container_logs = container_logs or self.base_container_name self.image_pull_policy = image_pull_policy self.node_selector = node_selector or {} @@ -620,6 +625,9 @@ def execute_sync(self, context: Context): self.callbacks.on_pod_creation( pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC ) + + self.await_init_containers_completion(pod=self.pod) + self.await_pod_start(pod=self.pod) if self.callbacks: self.callbacks.on_pod_starting( @@ -655,6 +663,31 @@ def execute_sync(self, context: Context): if self.do_xcom_push: return result + @tenacity.retry( + wait=tenacity.wait_exponential(max=15), + retry=tenacity.retry_if_exception_type(PodCredentialsExpiredFailure), + reraise=True, + ) + def await_init_containers_completion(self, pod: k8s.V1Pod): + try: + if self.init_container_logs: + self.pod_manager.fetch_requested_init_container_logs( + pod=pod, + init_containers=self.init_container_logs, + follow_logs=True, + ) + except kubernetes.client.exceptions.ApiException as exc: + if exc.status and str(exc.status) == "401": + self.log.warning( + "Failed to check container status due to permission error. Refreshing credentials and retrying." + ) + self._refresh_cached_properties() + self.pod_manager.read_pod( + pod=pod + ) # attempt using refreshed credentials, raises if still invalid + raise PodCredentialsExpiredFailure("Kubernetes credentials expired, retrying after refresh.") + raise exc + @tenacity.retry( wait=tenacity.wait_exponential(max=15), retry=tenacity.retry_if_exception_type(PodCredentialsExpiredFailure), diff --git a/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py index cd91dc09281f2..7dd8f7d023a97 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/providers/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -19,6 +19,7 @@ from __future__ import annotations import enum +import itertools import json import math import time @@ -118,7 +119,13 @@ def get_xcom_sidecar_container_resources(self) -> str | None: def get_container_status(pod: V1Pod, container_name: str) -> V1ContainerStatus | None: """Retrieve container status.""" - container_statuses = pod.status.container_statuses if pod and pod.status else None + if pod and pod.status: + container_statuses = itertools.chain( + pod.status.container_statuses, pod.status.init_container_statuses + ) + else: + container_statuses = None + if container_statuses: # In general the variable container_statuses can store multiple items matching different containers. # The following generator expression yields all items that have name equal to the container_name. @@ -565,6 +572,28 @@ def _reconcile_requested_log_containers( self.log.error("Could not retrieve containers for the pod: %s", pod_name) return containers_to_log + def fetch_requested_init_container_logs( + self, pod: V1Pod, init_containers: Iterable[str] | str | Literal[True] | None, follow_logs=False + ) -> list[PodLoggingStatus]: + """ + Follow the logs of containers in the specified pod and publish it to airflow logging. + + Returns when all the containers exit. + + :meta private: + """ + pod_logging_statuses = [] + all_containers = self.get_init_container_names(pod) + containers_to_log = self._reconcile_requested_log_containers( + requested=init_containers, + actual=all_containers, + pod_name=pod.metadata.name, + ) + for c in containers_to_log: + status = self.fetch_container_logs(pod=pod, container_name=c, follow=follow_logs) + pod_logging_statuses.append(status) + return pod_logging_statuses + def fetch_requested_container_logs( self, pod: V1Pod, containers: Iterable[str] | str | Literal[True], follow_logs=False ) -> list[PodLoggingStatus]: @@ -692,6 +721,12 @@ def read_pod_logs( post_termination_timeout=post_termination_timeout, ) + @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True) + def get_init_container_names(self, pod: V1Pod) -> list[str]: + """Return container names from the POD except for the airflow-xcom-sidecar container.""" + pod_info = self.read_pod(pod) + return [container_spec.name for container_spec in pod_info.spec.init_containers] + @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True) def get_container_names(self, pod: V1Pod) -> list[str]: """Return container names from the POD except for the airflow-xcom-sidecar container.""" diff --git a/providers/tests/cncf/kubernetes/utils/test_pod_manager.py b/providers/tests/cncf/kubernetes/utils/test_pod_manager.py index b577ea969ea3d..62a8065ed2d5c 100644 --- a/providers/tests/cncf/kubernetes/utils/test_pod_manager.py +++ b/providers/tests/cncf/kubernetes/utils/test_pod_manager.py @@ -639,6 +639,7 @@ def remote_pod(running=None, not_running=None): e = RemotePodMock() e.status = RemotePodMock() e.status.container_statuses = [] + e.status.init_container_statuses = [] for r in not_running or []: e.status.container_statuses.append(container(r, False)) for r in running or []: