Commit

Use argparse to setup spark (#2082)
mathbunnyru authored Jan 17, 2024
1 parent bf33945 commit afe30f0
Showing 2 changed files with 22 additions and 17 deletions.
10 changes: 5 additions & 5 deletions images/pyspark-notebook/Dockerfile
@@ -41,11 +41,11 @@ ENV SPARK_OPTS="--driver-java-options=-Xms1024M --driver-java-options=-Xmx4096M
COPY setup_spark.py /opt/setup-scripts/

# Setup Spark
-RUN SPARK_VERSION="${spark_version}" \
-    HADOOP_VERSION="${hadoop_version}" \
-    SCALA_VERSION="${scala_version}" \
-    SPARK_DOWNLOAD_URL="${spark_download_url}" \
-    /opt/setup-scripts/setup_spark.py
+RUN /opt/setup-scripts/setup_spark.py \
+    --spark-version="${spark_version}" \
+    --hadoop-version="${hadoop_version}" \
+    --scala-version="${scala_version}" \
+    --spark-download-url="${spark_download_url}"

# Configure IPython system-wide
COPY ipython_kernel_config.py "/etc/ipython/"
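Note on the new interface: the RUN line above forwards every build arg as a flag even when the arg is empty. As the setup_spark.py hunks below show, argparse's required=True only demands that each flag be present; an empty --spark-version is accepted and then replaced by the latest stable release. A minimal standalone sketch of that pattern (the fallback value here is illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--spark-version", required=True)

    # "--spark-version=" satisfies required=True but parses as the empty string
    args = parser.parse_args(["--spark-version="])

    # Empty strings are falsy, so the fallback takes over
    args.spark_version = args.spark_version or "3.5.0"  # stand-in for get_latest_spark_version()
    print(args.spark_version)  # -> 3.5.0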
29 changes: 17 additions & 12 deletions images/pyspark-notebook/setup_spark.py
@@ -4,9 +4,9 @@

# Requirements:
# - Run as the root user
-# - Required env variables: SPARK_HOME, HADOOP_VERSION, SPARK_DOWNLOAD_URL
-# - Optional env variables: SPARK_VERSION, SCALA_VERSION
+# - Required env variable: SPARK_HOME

+import argparse
import logging
import os
import subprocess
@@ -27,13 +27,10 @@ def get_all_refs(url: str) -> list[str]:
    return [a["href"] for a in soup.find_all("a", href=True)]


-def get_spark_version() -> str:
+def get_latest_spark_version() -> str:
    """
-    If ${SPARK_VERSION} env variable is non-empty, simply returns it
-    Otherwise, returns the last stable version of Spark using spark archive
+    Returns the last stable version of Spark using spark archive
    """
-    if (version := os.environ["SPARK_VERSION"]) != "":
-        return version
    LOGGER.info("Downloading Spark versions information")
    all_refs = get_all_refs("https://archive.apache.org/dist/spark/")
    stable_versions = [
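The hunk is truncated here (the rest of get_latest_spark_version is collapsed in the diff view). Roughly, it filters the archive listing down to stable release directories and picks the numerically greatest version; a runnable sketch of that logic under those assumptions (not the verbatim file contents):

    import requests
    from bs4 import BeautifulSoup

    def get_all_refs(url: str) -> list[str]:
        # Collect every href from the archive's directory listing
        soup = BeautifulSoup(requests.get(url).text, "html.parser")
        return [a["href"] for a in soup.find_all("a", href=True)]

    def get_latest_spark_version() -> str:
        all_refs = get_all_refs("https://archive.apache.org/dist/spark/")
        # Keep only stable release directories such as "spark-3.5.0/"
        versions = [
            ref.removesuffix("/").removeprefix("spark-")
            for ref in all_refs
            if ref.startswith("spark-") and "incubating" not in ref and "preview" not in ref
        ]
        # Compare components numerically so that, e.g., 3.10.0 beats 3.9.0
        return max(versions, key=lambda v: [int(p) for p in v.split(".")])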
@@ -106,12 +103,20 @@ def configure_spark(spark_dir_name: str, spark_home: Path) -> None:
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

-    spark_version = get_spark_version()
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument("--spark-version", required=True)
+    arg_parser.add_argument("--hadoop-version", required=True)
+    arg_parser.add_argument("--scala-version", required=True)
+    arg_parser.add_argument("--spark-download-url", type=Path, required=True)
+    args = arg_parser.parse_args()
+
+    args.spark_version = args.spark_version or get_latest_spark_version()
+
    spark_dir_name = download_spark(
-        spark_version=spark_version,
-        hadoop_version=os.environ["HADOOP_VERSION"],
-        scala_version=os.environ["SCALA_VERSION"],
-        spark_download_url=Path(os.environ["SPARK_DOWNLOAD_URL"]),
+        spark_version=args.spark_version,
+        hadoop_version=args.hadoop_version,
+        scala_version=args.scala_version,
+        spark_download_url=args.spark_download_url,
    )
    configure_spark(
        spark_dir_name=spark_dir_name, spark_home=Path(os.environ["SPARK_HOME"])
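One practical effect of required=True: invoking the script without one of the flags now fails fast with an argparse usage error instead of a KeyError from an os.environ lookup later on. A hypothetical check (path as used in the Dockerfile above; output not shown verbatim):

    import subprocess

    # Omit required flags; argparse aborts before any download starts
    result = subprocess.run(
        ["/opt/setup-scripts/setup_spark.py", "--spark-version=3.5.0"],
        capture_output=True,
        text=True,
    )
    print(result.returncode)  # 2, argparse's exit status for argument errors
    print(result.stderr)      # names the missing required flags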
