-
Notifications
You must be signed in to change notification settings - Fork 14.5k
/
pod.py
266 lines (240 loc) · 11.2 KB
/
pod.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import asyncio
import warnings
from asyncio import CancelledError
from datetime import datetime
from enum import Enum
from typing import Any, AsyncIterator
import pytz
from kubernetes_asyncio.client.models import V1Pod
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodPhase
from airflow.triggers.base import BaseTrigger, TriggerEvent
class ContainerState(str, Enum):
    """
    Possible states of a pod's base container, as classified by the Kubernetes pod trigger.

    ``WAITING``, ``RUNNING`` and ``TERMINATED`` correspond to the Kubernetes container states
    (see https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#container-states).
    Two values are derived rather than reported by Kubernetes:

    * ``TERMINATED`` is used only for a terminated container whose exit code is ``0``;
      a terminated container with any other exit code is reported as ``FAILED``.
    * ``UNDEFINED`` means no container state could be determined yet (for example, the pod
      has no container statuses).

    Members are ``str`` subclasses, so they compare equal to their lowercase string values.
    """

    WAITING = "waiting"
    RUNNING = "running"
    TERMINATED = "terminated"  # terminated with exit code 0
    FAILED = "failed"  # terminated with a non-zero exit code
    UNDEFINED = "undefined"  # state not (yet) known
class KubernetesPodTrigger(BaseTrigger):
    """
    KubernetesPodTrigger run on the trigger worker to check the state of Pod.

    The trigger polls the pod every ``poll_interval`` seconds and yields a single
    :class:`~airflow.triggers.base.TriggerEvent` whose ``status`` is one of ``"success"``,
    ``"failed"``, ``"timeout"``, ``"cancelled"`` or ``"error"``.

    :param pod_name: The name of the pod.
    :param pod_namespace: The namespace of the pod.
    :param trigger_start_time: time in Datetime format when the trigger was started. It is
        subtracted from ``datetime.now(tz=pytz.UTC)``, so it must be timezone-aware.
    :param base_container_name: The name of the container whose state decides the outcome
        of the trigger.
    :param kubernetes_conn_id: The :ref:`kubernetes connection id <howto/connection:kubernetes>`
        for the Kubernetes cluster.
    :param poll_interval: Polling period in seconds to check for the status.
    :param cluster_context: Context that points to kubernetes cluster.
    :param config_file: Path to kubeconfig file.
    :param in_cluster: run kubernetes client with in_cluster configuration.
    :param get_logs: get the stdout of the container as logs of the tasks. Within this
        trigger the logs are only read when the trigger is cancelled.
    :param startup_timeout: timeout in seconds to start up the pod, i.e. how long (measured
        from ``trigger_start_time``) the pod may stay in the ``Pending`` phase while its base
        container is not running before a ``"timeout"`` event is emitted.
    :param on_finish_action: What to do when the pod reaches its final state, or the execution is interrupted.
        If "delete_pod", the pod will be deleted regardless of its state; if "delete_succeeded_pod",
        only succeeded pod will be deleted. You can set to "keep_pod" to keep the pod.
    :param should_delete_pod: What to do when the pod reaches its final
        state, or the execution is interrupted. If True, delete the
        pod; if False, leave the pod. When it disagrees with ``on_finish_action``,
        this flag takes precedence.
        Deprecated - use `on_finish_action` instead.
    """

    def __init__(
        self,
        pod_name: str,
        pod_namespace: str,
        trigger_start_time: datetime,
        base_container_name: str,
        kubernetes_conn_id: str | None = None,
        poll_interval: float = 2,
        cluster_context: str | None = None,
        config_file: str | None = None,
        in_cluster: bool | None = None,
        get_logs: bool = True,
        startup_timeout: int = 120,
        on_finish_action: str = "delete_pod",
        should_delete_pod: bool | None = None,
    ):
        super().__init__()
        self.pod_name = pod_name
        self.pod_namespace = pod_namespace
        self.trigger_start_time = trigger_start_time
        self.base_container_name = base_container_name
        self.kubernetes_conn_id = kubernetes_conn_id
        self.poll_interval = poll_interval
        self.cluster_context = cluster_context
        self.config_file = config_file
        self.in_cluster = in_cluster
        self.get_logs = get_logs
        self.startup_timeout = startup_timeout

        self.on_finish_action = OnFinishAction(on_finish_action)
        if should_delete_pod is not None:
            # NOTE(review): serialize() always emits ``should_delete_pod``, so this warning is
            # also raised whenever the trigger is rebuilt from its serialized kwargs.
            warnings.warn(
                "`should_delete_pod` parameter is deprecated, please use `on_finish_action`",
                AirflowProviderDeprecationWarning,
                stacklevel=2,
            )
            # serialize() derives ``should_delete_pod`` from ``on_finish_action``. Overriding
            # unconditionally would turn e.g. "delete_succeeded_pod" into "keep_pod" after a
            # serialize/deserialize round trip, so the legacy flag only wins on a real conflict.
            if bool(should_delete_pod) != (self.on_finish_action == OnFinishAction.DELETE_POD):
                self.on_finish_action = (
                    OnFinishAction.DELETE_POD if should_delete_pod else OnFinishAction.KEEP_POD
                )
        self.should_delete_pod = self.on_finish_action == OnFinishAction.DELETE_POD

        self._hook: AsyncKubernetesHook | None = None
        self._since_time = None

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize KubernetesPodTrigger arguments and classpath."""
        return (
            "airflow.providers.cncf.kubernetes.triggers.pod.KubernetesPodTrigger",
            {
                "pod_name": self.pod_name,
                "pod_namespace": self.pod_namespace,
                "base_container_name": self.base_container_name,
                "kubernetes_conn_id": self.kubernetes_conn_id,
                "poll_interval": self.poll_interval,
                "cluster_context": self.cluster_context,
                "config_file": self.config_file,
                "in_cluster": self.in_cluster,
                "get_logs": self.get_logs,
                "startup_timeout": self.startup_timeout,
                "trigger_start_time": self.trigger_start_time,
                "should_delete_pod": self.should_delete_pod,
                "on_finish_action": self.on_finish_action.value,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
        """
        Poll the pod until a final outcome is known and yield one TriggerEvent describing it.

        * ``"success"`` - the base container terminated with exit code 0.
        * ``"timeout"`` - the pod stayed ``Pending`` without a running base container for
          longer than ``startup_timeout`` seconds after ``trigger_start_time``.
        * ``"failed"`` - the pod/container reached a state in which waiting makes no sense.
        * ``"cancelled"`` - the trigger was cancelled; logs may be read and the pod deleted
          according to ``get_logs`` / ``on_finish_action``.
        * ``"error"`` - any other exception was raised while checking the pod.
        """
        hook = self._get_async_hook()
        self.log.info("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace)
        while True:
            try:
                pod = await hook.get_pod(
                    name=self.pod_name,
                    namespace=self.pod_namespace,
                )

                pod_status = pod.status.phase
                self.log.debug("Pod %s status: %s", self.pod_name, pod_status)

                container_state = self.define_container_state(pod)
                self.log.debug("Container %s status: %s", self.base_container_name, container_state)

                if container_state == ContainerState.TERMINATED:
                    yield TriggerEvent(
                        {
                            "name": self.pod_name,
                            "namespace": self.pod_namespace,
                            "status": "success",
                            "message": "All containers inside pod have started successfully.",
                        }
                    )
                    return
                elif self.should_wait(pod_phase=pod_status, container_state=container_state):
                    self.log.info("Container is not completed and still working.")

                    # The startup timeout covers the whole Pending phase, not only the window
                    # before container statuses appear: a base container stuck in WAITING
                    # (e.g. an image that cannot be pulled) must also time out instead of
                    # being polled forever.
                    if pod_status == PodPhase.PENDING and container_state != ContainerState.RUNNING:
                        delta = datetime.now(tz=pytz.UTC) - self.trigger_start_time
                        if delta.total_seconds() >= self.startup_timeout:
                            message = (
                                f"Pod took longer than {self.startup_timeout} seconds to start. "
                                "Check the pod events in kubernetes to determine why."
                            )
                            yield TriggerEvent(
                                {
                                    "name": self.pod_name,
                                    "namespace": self.pod_namespace,
                                    "status": "timeout",
                                    "message": message,
                                }
                            )
                            return

                    self.log.info("Sleeping for %s seconds.", self.poll_interval)
                    await asyncio.sleep(self.poll_interval)
                else:
                    yield TriggerEvent(
                        {
                            "name": self.pod_name,
                            "namespace": self.pod_namespace,
                            "status": "failed",
                            "message": pod.status.message,
                        }
                    )
                    return
            except CancelledError:
                # That means that task was marked as failed
                if self.get_logs:
                    self.log.info("Outputting container logs...")
                    await self._get_async_hook().read_logs(
                        name=self.pod_name,
                        namespace=self.pod_namespace,
                    )
                if self.on_finish_action == OnFinishAction.DELETE_POD:
                    self.log.info("Deleting pod...")
                    await self._get_async_hook().delete_pod(
                        name=self.pod_name,
                        namespace=self.pod_namespace,
                    )
                yield TriggerEvent(
                    {
                        "name": self.pod_name,
                        "namespace": self.pod_namespace,
                        "status": "cancelled",
                        "message": "Pod execution was cancelled",
                    }
                )
                return
            except Exception as e:
                self.log.exception("Exception occurred while checking pod phase:")
                yield TriggerEvent(
                    {
                        "name": self.pod_name,
                        "namespace": self.pod_namespace,
                        "status": "error",
                        "message": str(e),
                    }
                )
                return

    def _get_async_hook(self) -> AsyncKubernetesHook:
        """Return the cached AsyncKubernetesHook, creating it from the trigger's settings on first use."""
        if self._hook is None:
            self._hook = AsyncKubernetesHook(
                conn_id=self.kubernetes_conn_id,
                in_cluster=self.in_cluster,
                config_file=self.config_file,
                cluster_context=self.cluster_context,
            )
        return self._hook

    def define_container_state(self, pod: V1Pod) -> ContainerState:
        """
        Classify the state of the base container of ``pod``.

        Returns ``UNDEFINED`` when the pod has no container statuses or the container has no
        state set; ``RUNNING``/``WAITING`` as reported; and, for a terminated container,
        ``TERMINATED`` on exit code 0 or ``FAILED`` otherwise.

        :raises IndexError: if container statuses exist but none is named ``base_container_name``.
        """
        pod_containers = pod.status.container_statuses

        if pod_containers is None:
            return ContainerState.UNDEFINED

        container = [c for c in pod_containers if c.name == self.base_container_name][0]

        # ContainerState is a str enum, so its members double as attribute names on the
        # Kubernetes container state object (``running`` / ``waiting`` / ``terminated``).
        for state in (ContainerState.RUNNING, ContainerState.WAITING, ContainerState.TERMINATED):
            state_obj = getattr(container.state, state)
            if state_obj is not None:
                if state != ContainerState.TERMINATED:
                    return state
                else:
                    return ContainerState.TERMINATED if state_obj.exit_code == 0 else ContainerState.FAILED
        return ContainerState.UNDEFINED

    @staticmethod
    def should_wait(pod_phase: PodPhase, container_state: ContainerState) -> bool:
        """Return True while the container is waiting/running, or unknown in a Pending pod."""
        return (
            container_state == ContainerState.WAITING
            or container_state == ContainerState.RUNNING
            or (container_state == ContainerState.UNDEFINED and pod_phase == PodPhase.PENDING)
        )