SparkSubmit Connection Extras can be overridden #36151

Merged
11 commits merged on Dec 13, 2023
13 changes: 11 additions & 2 deletions airflow/providers/apache/spark/hooks/spark_submit.py
@@ -78,8 +78,13 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
:param verbose: Whether to pass the verbose flag to the spark-submit process for debugging
:param spark_binary: The command to use for spark submit.
Some distros may use spark2-submit or spark3-submit.
(will overwrite any spark_binary defined in the connection's extra JSON)
:param properties_file: Path to a file from which to load extra properties. If not
specified, this will look for conf/spark-defaults.conf.
:param queue: The name of the YARN queue to which the application is submitted.
(will overwrite any YARN queue defined in the connection's extra JSON)
:param deploy_mode: Whether to deploy your driver on the worker nodes (cluster) or locally as a client.
(will overwrite any deployment mode defined in the connection's extra JSON)
:param use_krb5ccache: if True, configure spark to use ticket cache instead of relying
on keytab for Kerberos login
"""
@@ -125,6 +130,8 @@ def __init__(
verbose: bool = False,
spark_binary: str | None = None,
properties_file: str | None = None,
queue: str | None = None,
deploy_mode: str | None = None,
*,
use_krb5ccache: bool = False,
) -> None:
@@ -159,6 +166,8 @@ def __init__(
self._kubernetes_driver_pod: str | None = None
self.spark_binary = spark_binary
self._properties_file = properties_file
self._queue = queue
self._deploy_mode = deploy_mode
self._connection = self._resolve_connection()
self._is_yarn = "yarn" in self._connection["master"]
self._is_kubernetes = "k8s" in self._connection["master"]
@@ -204,8 +213,8 @@ def _resolve_connection(self) -> dict[str, Any]:

# Determine optional YARN queue and deploy mode, preferring explicit arguments over the extra field
extra = conn.extra_dejson
conn_data["queue"] = extra.get("queue")
conn_data["deploy_mode"] = extra.get("deploy-mode")
conn_data["queue"] = self._queue if self._queue else extra.get("queue")
Contributor Author:
Override the values if passed; otherwise fetch from Spark's Connection Extras.

conn_data["deploy_mode"] = self._deploy_mode if self._deploy_mode else extra.get("deploy-mode")
if not self.spark_binary:
self.spark_binary = extra.get("spark-binary", "spark-submit")
if self.spark_binary is not None and self.spark_binary not in ALLOWED_SPARK_BINARIES:
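
To make the new precedence explicit, here is a minimal sketch of the resolution rule _resolve_connection now applies (the extras dict and values below are illustrative, not the provider's actual code): an argument passed to the hook wins, and the connection's extra JSON is only the fallback.

    # Sketch of the override rule; "extra" stands in for conn.extra_dejson.
    extra = {"queue": "root.default", "deploy-mode": "client"}

    def resolve(arg_value, extra_key):
        # Prefer the explicit argument; fall back to the connection extra.
        return arg_value if arg_value else extra.get(extra_key)

    conn_data = {
        "queue": resolve("yarn_dev_queue2", "queue"),  # override wins -> "yarn_dev_queue2"
        "deploy_mode": resolve(None, "deploy-mode"),   # no override   -> "client"
    }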
11 changes: 11 additions & 0 deletions airflow/providers/apache/spark/operators/spark_submit.py
@@ -69,8 +69,13 @@ class SparkSubmitOperator(BaseOperator):
:param verbose: Whether to pass the verbose flag to the spark-submit process for debugging
:param spark_binary: The command to use for spark submit.
Some distros may use spark2-submit or spark3-submit.
(will overwrite any spark_binary defined in the connection's extra JSON)
:param properties_file: Path to a file from which to load extra properties. If not
specified, this will look for conf/spark-defaults.conf.
:param queue: The name of the YARN queue to which the application is submitted.
(will overwrite any YARN queue defined in the connection's extra JSON)
:param deploy_mode: Whether to deploy your driver on the worker nodes (cluster) or locally as a client.
(will overwrite any deployment mode defined in the connection's extra JSON)
:param use_krb5ccache: if True, configure spark to use ticket cache instead of relying
on keytab for Kerberos login
"""
@@ -124,6 +129,8 @@ def __init__(
verbose: bool = False,
spark_binary: str | None = None,
properties_file: str | None = None,
queue: str | None = None,
deploy_mode: str | None = None,
use_krb5ccache: bool = False,
**kwargs: Any,
) -> None:
@@ -154,6 +161,8 @@ def __init__(
self._verbose = verbose
self._spark_binary = spark_binary
self._properties_file = properties_file
self._queue = queue
self._deploy_mode = deploy_mode
self._hook: SparkSubmitHook | None = None
self._conn_id = conn_id
self._use_krb5ccache = use_krb5ccache
@@ -197,5 +206,7 @@ def _get_hook(self) -> SparkSubmitHook:
verbose=self._verbose,
spark_binary=self._spark_binary,
properties_file=self._properties_file,
queue=self._queue,
deploy_mode=self._deploy_mode,
use_krb5ccache=self._use_krb5ccache,
)
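
A usage sketch of the new operator parameters (the task, application path, and connection values are illustrative; the connection's extra JSON is assumed to define {"queue": "root.default", "deploy-mode": "client"}, which the arguments below take precedence over):

    from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

    submit = SparkSubmitOperator(
        task_id="submit_app",
        application="/path/to/app.py",  # hypothetical application
        conn_id="spark_default",
        queue="yarn_dev_queue2",        # overrides extra["queue"]
        deploy_mode="cluster",          # overrides extra["deploy-mode"]
    )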
29 changes: 29 additions & 0 deletions tests/providers/apache/spark/operators/test_spark_submit.py
@@ -67,6 +67,8 @@ class TestSparkSubmitOperator:
"args should keep embedded spaces",
],
"use_krb5ccache": True,
"queue": "yarn_dev_queue2",
"deploy_mode": "client2",
}

def setup_method(self):
@@ -120,6 +122,8 @@ def test_execute(self):
"args should keep embedded spaces",
],
"spark_binary": "sparky",
"queue": "yarn_dev_queue2",
"deploy_mode": "client2",
"use_krb5ccache": True,
"properties_file": "conf/spark-custom.conf",
}
@@ -149,9 +153,34 @@ def test_execute(self):
assert expected_dict["driver_memory"] == operator._driver_memory
assert expected_dict["application_args"] == operator._application_args
assert expected_dict["spark_binary"] == operator._spark_binary
assert expected_dict["queue"] == operator._queue
assert expected_dict["deploy_mode"] == operator._deploy_mode
assert expected_dict["properties_file"] == operator._properties_file
assert expected_dict["use_krb5ccache"] == operator._use_krb5ccache

@pytest.mark.db_test
def test_spark_submit_cmd_connection_overrides(self):
config = self._config
# use_krb5ccache has to be disabled here, otherwise
# _build_spark_submit_command can't run
config["use_krb5ccache"] = False
operator = SparkSubmitOperator(
task_id="spark_submit_job", spark_binary="sparky", dag=self.dag, **config
)
cmd = " ".join(operator._get_hook()._build_spark_submit_command("test"))
assert "--queue yarn_dev_queue2" in cmd
assert "--deploy-mode client2" in cmd
assert "sparky" in cmd

# if we don't pass any overrides as arguments, the connection's extras are used
config["queue"] = None
config["deploy_mode"] = None
operator2 = SparkSubmitOperator(task_id="spark_submit_job2", dag=self.dag, **config)
cmd2 = " ".join(operator2._get_hook()._build_spark_submit_command("test"))
assert "--queue root.default" in cmd2
assert "--deploy-mode client2" not in cmd2
assert "spark-submit" in cmd2

@pytest.mark.db_test
def test_render_template(self):
# Given