diff --git a/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py b/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py index a53af7d6e30e9..9bc821a816d99 100644 --- a/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py +++ b/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py @@ -76,10 +76,20 @@ def convert_env_vars(env_vars: list[k8s.V1EnvVar] | dict[str, str]) -> list[k8s. If the collection is a str-str dict, convert it into a list of ``V1EnvVar``s. """ - if isinstance(env_vars, list): - return env_vars if isinstance(env_vars, dict): return [k8s.V1EnvVar(name=k, value=v) for k, v in env_vars.items()] + return env_vars + + +def convert_env_vars_or_raise_error(env_vars: list[k8s.V1EnvVar] | dict[str, str]) -> list[k8s.V1EnvVar]: + """ + Separate function to convert env var collection for kubernetes and then raise an error if it is still the wrong type. + + This is used after the template strings have been rendered. + """ + env_vars = convert_env_vars(env_vars) + if isinstance(env_vars, list): + return env_vars raise AirflowException(f"Expected dict or list, got {type(env_vars)}") diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py index a331987b57ff0..0b387352ae8a6 100644 --- a/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/airflow/providers/cncf/kubernetes/operators/pod.py @@ -51,6 +51,7 @@ convert_affinity, convert_configmap, convert_env_vars, + convert_env_vars_or_raise_error, convert_image_pull_secrets, convert_pod_runtime_info_env, convert_port, @@ -332,8 +333,10 @@ def __init__( self.startup_check_interval_seconds = startup_check_interval_seconds env_vars = convert_env_vars(env_vars) if env_vars else [] self.env_vars = env_vars - if pod_runtime_info_envs: - self.env_vars.extend([convert_pod_runtime_info_env(p) for p in pod_runtime_info_envs]) + pod_runtime_info_envs = ( + [convert_pod_runtime_info_env(p) for p in pod_runtime_info_envs] if pod_runtime_info_envs else [] + ) + self.pod_runtime_info_envs = pod_runtime_info_envs self.env_from = env_from or [] if configmaps: self.env_from.extend([convert_configmap(c) for c in configmaps]) @@ -985,6 +988,11 @@ def build_pod_request_obj(self, context: Context | None = None) -> k8s.V1Pod: template file. """ self.log.debug("Creating pod for KubernetesPodOperator task %s", self.task_id) + + self.env_vars = convert_env_vars_or_raise_error(self.env_vars) if self.env_vars else [] + if self.pod_runtime_info_envs: + self.env_vars.extend(self.pod_runtime_info_envs) + if self.pod_template_file: self.log.debug("Pod template file found, will parse for base pod") pod_template = pod_generator.PodGenerator.deserialize_model_file(self.pod_template_file) diff --git a/pyproject.toml b/pyproject.toml index eac3c81f8ef68..c8c64b0847cd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,7 +125,7 @@ dynamic = ["version", "optional-dependencies", "dependencies"] # # END DEPRECATED EXTRAS HERE # -# !!!!!! Those provuders are defined in the `airflow/providers//provider.yaml` files !!!!!!! +# !!!!!! Those providers are defined in the `airflow/providers//provider.yaml` files !!!!!!! # # Those extras are available as regular Airflow extras, they install provider packages in standard builds # or dependencies that are necessary to enable the feature in editable build. diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py b/tests/providers/cncf/kubernetes/operators/test_pod.py index a32a944db7321..db10bff330b01 100644 --- a/tests/providers/cncf/kubernetes/operators/test_pod.py +++ b/tests/providers/cncf/kubernetes/operators/test_pod.py @@ -235,20 +235,47 @@ def test_config_path(self, hook_mock): ) @pytest.mark.parametrize( - "input", + "input,render_template_as_native_obj,raises_error", [ - pytest.param([k8s.V1EnvVar(name="{{ bar }}", value="{{ foo }}")], id="current"), - pytest.param({"{{ bar }}": "{{ foo }}"}, id="backcompat"), + pytest.param([k8s.V1EnvVar(name="{{ bar }}", value="{{ foo }}")], False, False, id="current"), + pytest.param({"{{ bar }}": "{{ foo }}"}, False, False, id="backcompat"), + pytest.param("{{ env }}", True, False, id="xcom_args"), + pytest.param("bad env", False, True, id="error"), ], ) - def test_env_vars(self, input): + def test_env_vars(self, input, render_template_as_native_obj, raises_error): + dag = DAG( + dag_id="dag", + start_date=pendulum.now(), + render_template_as_native_obj=render_template_as_native_obj, + ) k = KubernetesPodOperator( env_vars=input, task_id="task", + name="test", + dag=dag, + ) + k.render_template_fields( + context={"foo": "footemplated", "bar": "bartemplated", "env": {"bartemplated": "footemplated"}} + ) + if raises_error: + with pytest.raises(AirflowException): + k.build_pod_request_obj() + else: + k.build_pod_request_obj() + assert k.env_vars[0].name == "bartemplated" + assert k.env_vars[0].value == "footemplated" + + def test_pod_runtime_info_envs(self): + k = KubernetesPodOperator( + task_id="task", + name="test", + pod_runtime_info_envs=[k8s.V1EnvVar(name="bar", value="foo")], ) - k.render_template_fields(context={"foo": "footemplated", "bar": "bartemplated"}) - assert k.env_vars[0].value == "footemplated" - assert k.env_vars[0].name == "bartemplated" + k.build_pod_request_obj() + + assert k.env_vars[0].name == "bar" + assert k.env_vars[0].value == "foo" def test_security_context(self): security_context = V1PodSecurityContext(run_as_user=1245)