diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 5073a86dd68a..c597b4a9e962 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -1315,7 +1315,8 @@ def _add_argparse_args(cls, parser): 'the execution.') parser.add_argument( '--spark_job_server_jar', - help='Path or URL to a Beam Spark jobserver jar.') + help='Path or URL to a Beam Spark job server jar. ' + 'Overrides --spark_version.') parser.add_argument( '--spark_submit_uber_jar', default=False, @@ -1328,6 +1329,11 @@ def _add_argparse_args(cls, parser): help='URL for the Spark REST endpoint. ' 'Only required when using spark_submit_uber_jar. ' 'For example, http://hostname:6066') + parser.add_argument( + '--spark_version', + default='2', + choices=['2', '3'], + help='Spark major version to use.') class TestOptions(PipelineOptions): diff --git a/sdks/python/apache_beam/runners/portability/spark_runner.py b/sdks/python/apache_beam/runners/portability/spark_runner.py index 11455073725a..bb7b4c0465e2 100644 --- a/sdks/python/apache_beam/runners/portability/spark_runner.py +++ b/sdks/python/apache_beam/runners/portability/spark_runner.py @@ -77,6 +77,7 @@ def __init__(self, options): options = options.view_as(pipeline_options.SparkRunnerOptions) self._jar = options.spark_job_server_jar self._master_url = options.spark_master_url + self._spark_version = options.spark_version def path_to_jar(self): if self._jar: @@ -91,6 +92,8 @@ def path_to_jar(self): self._jar) return self._jar else: + if self._spark_version == '3': + return self.path_to_beam_jar(':runners:spark:3:job-server:shadowJar') return self.path_to_beam_jar( ':runners:spark:2:job-server:shadowJar', artifact_id='beam-runners-spark-job-server') diff --git a/sdks/python/apache_beam/runners/portability/spark_uber_jar_job_server.py b/sdks/python/apache_beam/runners/portability/spark_uber_jar_job_server.py index 3adcabec20f3..60b2e88357c9 100644 --- a/sdks/python/apache_beam/runners/portability/spark_uber_jar_job_server.py +++ b/sdks/python/apache_beam/runners/portability/spark_uber_jar_job_server.py @@ -47,12 +47,12 @@ class SparkUberJarJobServer(abstract_job_service.AbstractJobServiceServicer): def __init__(self, rest_url, options): super(SparkUberJarJobServer, self).__init__() self._rest_url = rest_url - self._executable_jar = ( - options.view_as( - pipeline_options.SparkRunnerOptions).spark_job_server_jar) self._artifact_port = ( options.view_as(pipeline_options.JobServerOptions).artifact_port) self._temp_dir = tempfile.mkdtemp(prefix='apache-beam-spark') + spark_options = options.view_as(pipeline_options.SparkRunnerOptions) + self._executable_jar = spark_options.spark_job_server_jar + self._spark_version = spark_options.spark_version def start(self): return self @@ -73,9 +73,13 @@ def executable_jar(self): self._executable_jar) url = self._executable_jar else: - url = job_server.JavaJarJobServer.path_to_beam_jar( - ':runners:spark:2:job-server:shadowJar', - artifact_id='beam-runners-spark-job-server') + if self._spark_version == '3': + url = job_server.JavaJarJobServer.path_to_beam_jar( + ':runners:spark:3:job-server:shadowJar') + else: + url = job_server.JavaJarJobServer.path_to_beam_jar( + ':runners:spark:2:job-server:shadowJar', + artifact_id='beam-runners-spark-job-server') return job_server.JavaJarJobServer.local_jar(url) def create_beam_job(self, job_id, job_name, pipeline, options):