Add use_krb5ccache arg
zeotuan committed Oct 19, 2023
1 parent 8d0ad02 commit 3898c09
Showing 2 changed files with 58 additions and 2 deletions.
33 changes: 32 additions & 1 deletion airflow/providers/apache/spark/hooks/spark_submit.py
@@ -78,6 +78,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
:param verbose: Whether to pass the verbose flag to spark-submit process for debugging
:param spark_binary: The command to use for spark submit.
Some distros may use spark2-submit or spark3-submit.
:param use_krb5ccache: if True, configure Spark to use the Kerberos ticket cache instead of
    relying on a keytab for the Kerberos login
"""

conn_name_attr = "conn_id"
@@ -120,6 +122,8 @@ def __init__(
env_vars: dict[str, Any] | None = None,
verbose: bool = False,
spark_binary: str | None = None,
*,
use_krb5ccache: bool = False,
) -> None:
super().__init__()
self._conf = conf or {}
@@ -138,7 +142,8 @@ def __init__(
self._executor_memory = executor_memory
self._driver_memory = driver_memory
self._keytab = keytab
self._principal = principal
self._principal = self._resolve_kerberos_principal(principal) if use_krb5ccache else principal
self._use_krb5ccache = use_krb5ccache
self._proxy_user = proxy_user
self._name = name
self._num_executors = num_executors
@@ -317,6 +322,12 @@ def _build_spark_submit_command(self, application: str) -> list[str]:
connection_cmd += ["--keytab", self._keytab]
if self._principal:
connection_cmd += ["--principal", self._principal]
if self._use_krb5ccache:
if not os.getenv("KRB5CCNAME"):
raise AirflowException(
"KRB5CCNAME environment variable required to use ticket ccache is missing."
)
connection_cmd += ["--conf", "spark.kerberos.renewal.credentials=ccache"]
if self._proxy_user:
connection_cmd += ["--proxy-user", self._proxy_user]
if self._name:
@@ -383,6 +394,26 @@ def _build_track_driver_status_command(self) -> list[str]:

return connection_cmd

def _resolve_kerberos_principal(self, principal: str | None) -> str:
"""Resolve the Kerberos principal, importing from airflow.security.kerberos on Airflow >= 2.8.

TODO: delete when the minimum Airflow version is >= 2.8 and import directly from airflow.security.kerberos.
"""
from packaging.version import Version

from airflow.version import version

if Version(version) < Version("2.8"):
from airflow.utils.net import get_hostname

return principal or airflow_conf.get_mandatory_value("kerberos", "principal").replace(
"_HOST", get_hostname()
)
else:
from airflow.security.kerberos import get_kerberos_principle

return get_kerberos_principle(principal)

def submit(self, application: str = "", **kwargs: Any) -> None:
"""
Remote Popen to execute the spark-submit job.
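
For context, a minimal usage sketch (not part of this diff) of how the new argument might be wired up; the connection id, principal, and application path are hypothetical, and KRB5CCNAME is expected to point at a ticket cache kept fresh by the airflow kerberos sidecar:

import os

from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook

# Assumption: a valid Kerberos ticket cache already exists at this path.
os.environ.setdefault("KRB5CCNAME", "/tmp/airflow_krb5_ccache")

hook = SparkSubmitHook(
    conn_id="spark_default",                # hypothetical connection id
    principal="user/spark@EXAMPLE.COM",     # falls back to the [kerberos] principal config when omitted
    use_krb5ccache=True,                    # adds --conf spark.kerberos.renewal.credentials=ccache
)
hook.submit(application="/path/to/app.py")  # hypothetical application path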
27 changes: 26 additions & 1 deletion tests/providers/apache/spark/hooks/test_spark_submit.py
@@ -61,6 +61,7 @@ class TestSparkSubmitHook:
"args should keep embedded spaces",
"baz",
],
"use_krb5ccache": True,
}

@staticmethod
@@ -141,7 +142,10 @@ def setup_method(self):
)
)

def test_build_spark_submit_command(self):
@patch(
"airflow.providers.apache.spark.hooks.spark_submit.os.getenv", return_value="/tmp/airflow_krb5_ccache"
)
def test_build_spark_submit_command(self, mock_get_env):
# Given
hook = SparkSubmitHook(**self._config)

@@ -183,6 +187,8 @@ def test_build_spark_submit_command(self):
"privileged_user.keytab",
"--principal",
"user/[email protected]",
"--conf",
"spark.kerberos.renewal.credentials=ccache",
"--proxy-user",
"sample_user",
"--name",
@@ -200,6 +206,25 @@ def test_build_spark_submit_command(self):
"baz",
]
assert expected_build_cmd == cmd
mock_get_env.assert_called_with("KRB5CCNAME")

@patch("airflow.configuration.conf.get_mandatory_value")
def test_resolve_spark_submit_env_vars_use_krb5ccache_missing_principal(self, mock_get_madantory_value):
mock_principle = "airflow"
mock_get_madantory_value.return_value = mock_principle
hook = SparkSubmitHook(conn_id="spark_yarn_cluster", principal=None, use_krb5ccache=True)
mock_get_madantory_value.assert_called_with("kerberos", "principal")
assert hook._principal == mock_principle

def test_resolve_spark_submit_env_vars_use_krb5ccache_missing_KRB5CCNAME_env(self):
hook = SparkSubmitHook(
conn_id="spark_yarn_cluster", principal="user/[email protected]", use_krb5ccache=True
)
with pytest.raises(
AirflowException,
match="KRB5CCNAME environment variable required to use ticket ccache is missing.",
):
hook._build_spark_submit_command(self._spark_job_file)

def test_build_track_driver_status_command(self):
# note this function is only relevant for spark setup matching below condition
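
Outside the test suite, the effect of the flag can be sanity-checked by building the command directly; a rough sketch, assuming an Airflow environment with a spark_default connection (or the provider's defaults), KRB5CCNAME exported, and using the private _build_spark_submit_command helper purely for illustration:

import os

from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook

os.environ["KRB5CCNAME"] = "/tmp/airflow_krb5_ccache"  # assumption: path to an existing ticket cache

hook = SparkSubmitHook(conn_id="spark_default", principal="airflow", use_krb5ccache=True)
cmd = hook._build_spark_submit_command("/path/to/app.py")  # hypothetical application path
print(" ".join(cmd))
# The output should include: --principal airflow --conf spark.kerberos.renewal.credentials=ccache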
