diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py b/airflow/providers/apache/spark/hooks/spark_submit.py index 75c13c8099d4a..d519eb3e6e2b3 100644 --- a/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/airflow/providers/apache/spark/hooks/spark_submit.py @@ -78,6 +78,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin): :param verbose: Whether to pass the verbose flag to spark-submit process for debugging :param spark_binary: The command to use for spark submit. Some distros may use spark2-submit or spark3-submit. + :param use_krb5ccache: if True, configure spark to use ticket cache instead of relying + on keytab for Kerberos login """ conn_name_attr = "conn_id" @@ -120,6 +122,8 @@ def __init__( env_vars: dict[str, Any] | None = None, verbose: bool = False, spark_binary: str | None = None, + *, + use_krb5ccache: bool = False, ) -> None: super().__init__() self._conf = conf or {} @@ -138,7 +142,8 @@ def __init__( self._executor_memory = executor_memory self._driver_memory = driver_memory self._keytab = keytab - self._principal = principal + self._principal = self._resolve_kerberos_principal(principal) if use_krb5ccache else principal + self._use_krb5ccache = use_krb5ccache self._proxy_user = proxy_user self._name = name self._num_executors = num_executors @@ -317,6 +322,12 @@ def _build_spark_submit_command(self, application: str) -> list[str]: connection_cmd += ["--keytab", self._keytab] if self._principal: connection_cmd += ["--principal", self._principal] + if self._use_krb5ccache: + if not os.getenv("KRB5CCNAME"): + raise AirflowException( + "KRB5CCNAME environment variable required to use ticket ccache is missing." + ) + connection_cmd += ["--conf", "spark.kerberos.renewal.credentials=ccache"] if self._proxy_user: connection_cmd += ["--proxy-user", self._proxy_user] if self._name: @@ -383,6 +394,26 @@ def _build_track_driver_status_command(self) -> list[str]: return connection_cmd + def _resolve_kerberos_principal(self, principal: str | None) -> str: + """Resolve kerberos principal if airflow > 2.8. + + TODO: delete when min airflow version >= 2.8 and import directly from airflow.security.kerberos + """ + from packaging.version import Version + + from airflow.version import version + + if Version(version) < Version("2.8"): + from airflow.utils.net import get_hostname + + return principal or airflow_conf.get_mandatory_value("kerberos", "principal").replace( + "_HOST", get_hostname() + ) + else: + from airflow.security.kerberos import get_kerberos_principle + + return get_kerberos_principle(principal) + def submit(self, application: str = "", **kwargs: Any) -> None: """ Remote Popen to execute the spark-submit job. diff --git a/tests/providers/apache/spark/hooks/test_spark_submit.py b/tests/providers/apache/spark/hooks/test_spark_submit.py index d3feccb081edd..9bf828e94da83 100644 --- a/tests/providers/apache/spark/hooks/test_spark_submit.py +++ b/tests/providers/apache/spark/hooks/test_spark_submit.py @@ -61,6 +61,7 @@ class TestSparkSubmitHook: "args should keep embedded spaces", "baz", ], + "use_krb5ccache": True, } @staticmethod @@ -141,7 +142,10 @@ def setup_method(self): ) ) - def test_build_spark_submit_command(self): + @patch( + "airflow.providers.apache.spark.hooks.spark_submit.os.getenv", return_value="/tmp/airflow_krb5_ccache" + ) + def test_build_spark_submit_command(self, mock_get_env): # Given hook = SparkSubmitHook(**self._config) @@ -183,6 +187,8 @@ def test_build_spark_submit_command(self): "privileged_user.keytab", "--principal", "user/spark@airflow.org", + "--conf", + "spark.kerberos.renewal.credentials=ccache", "--proxy-user", "sample_user", "--name", @@ -200,6 +206,25 @@ def test_build_spark_submit_command(self): "baz", ] assert expected_build_cmd == cmd + mock_get_env.assert_called_with("KRB5CCNAME") + + @patch("airflow.configuration.conf.get_mandatory_value") + def test_resolve_spark_submit_env_vars_use_krb5ccache_missing_principal(self, mock_get_madantory_value): + mock_principle = "airflow" + mock_get_madantory_value.return_value = mock_principle + hook = SparkSubmitHook(conn_id="spark_yarn_cluster", principal=None, use_krb5ccache=True) + mock_get_madantory_value.assert_called_with("kerberos", "principal") + assert hook._principal == mock_principle + + def test_resolve_spark_submit_env_vars_use_krb5ccache_missing_KRB5CCNAME_env(self): + hook = SparkSubmitHook( + conn_id="spark_yarn_cluster", principal="user/spark@airflow.org", use_krb5ccache=True + ) + with pytest.raises( + AirflowException, + match="KRB5CCNAME environment variable required to use ticket ccache is missing.", + ): + hook._build_spark_submit_command(self._spark_job_file) def test_build_track_driver_status_command(self): # note this function is only relevant for spark setup matching below condition