diff --git a/src/databricks/labs/pylint/all.py b/src/databricks/labs/pylint/all.py index 632a3fa..b5f25ae 100644 --- a/src/databricks/labs/pylint/all.py +++ b/src/databricks/labs/pylint/all.py @@ -2,6 +2,7 @@ from databricks.labs.pylint.dbutils import DbutilsChecker from databricks.labs.pylint.legacy import LegacyChecker from databricks.labs.pylint.notebooks import NotebookChecker +from databricks.labs.pylint.spark import SparkChecker def register(linter): @@ -9,3 +10,4 @@ def register(linter): linter.register_checker(DbutilsChecker(linter)) linter.register_checker(LegacyChecker(linter)) linter.register_checker(AirflowChecker(linter)) + linter.register_checker(SparkChecker(linter)) diff --git a/src/databricks/labs/pylint/spark.py b/src/databricks/labs/pylint/spark.py new file mode 100644 index 0000000..7bc54de --- /dev/null +++ b/src/databricks/labs/pylint/spark.py @@ -0,0 +1,40 @@ +import astroid +from pylint.checkers import BaseChecker + + +class SparkChecker(BaseChecker): + name = "spark" + + msgs = { + "E9700": ( + "Using spark outside the function is leading to untestable code", + "spark-outside-function", + "spark used outside of function", + ), + "E9701": ( + "Function %s is missing a 'spark' argument", + "no-spark-argument-in-function", + "function missing spark argument", + ), + } + + def visit_name(self, node: astroid.Name): + if node.name != "spark": + return + in_node = node + while in_node and not isinstance(in_node, astroid.FunctionDef): + in_node = in_node.parent + if not in_node: + self.add_message("spark-outside-function", node=node) + return + has_spark_arg = False + for arg in in_node.args.arguments: + if arg.name == "spark": + has_spark_arg = True + break + if not has_spark_arg: + self.add_message("no-spark-argument-in-function", node=in_node, args=(in_node.name,)) + + +def register(linter): + linter.register_checker(SparkChecker(linter)) diff --git a/tests/samples/p/percent_run.py b/tests/samples/p/percent_run.py index ebe9e14..3dd248c 100644 --- a/tests/samples/p/percent_run.py +++ b/tests/samples/p/percent_run.py @@ -12,7 +12,7 @@ # COMMAND ---------- -df = spark.table("samples.nyctaxi.trips").limit(10) +df = spark.table("samples.nyctaxi.trips").limit(10) # [spark-outside-function] display(df) # COMMAND ---------- @@ -21,5 +21,5 @@ # COMMAND ---------- -df = spark.table("samples.nyctaxi.trips").limit(10) +df = spark.table("samples.nyctaxi.trips").limit(10) # [spark-outside-function] display(df) diff --git a/tests/samples/p/percent_run.txt b/tests/samples/p/percent_run.txt index 860bcd9..8219222 100644 --- a/tests/samples/p/percent_run.txt +++ b/tests/samples/p/percent_run.txt @@ -1 +1,3 @@ -notebooks-percent-run:7:0::::Using %run is not allowed:UNDEFINED \ No newline at end of file +notebooks-percent-run:7:0:None:None::Using %run is not allowed:UNDEFINED +spark-outside-function:15:5:15:10::Using spark outside the function is leading to untestable code:UNDEFINED +spark-outside-function:24:5:24:10::Using spark outside the function is leading to untestable code:UNDEFINED diff --git a/tests/test_spark.py b/tests/test_spark.py new file mode 100644 index 0000000..f113ad5 --- /dev/null +++ b/tests/test_spark.py @@ -0,0 +1,39 @@ +from databricks.labs.pylint.spark import SparkChecker + + +def test_spark_inside_function(lint_with): + messages = ( + lint_with(SparkChecker) + << """def do_something(spark, x): + for i in range(10): + if i > 3: + continue + spark #@ +""" + ) + assert not messages + + +def test_spark_outside_function(lint_with): + messages = ( + lint_with(SparkChecker) + << """for i in range(10): + if i > 3: + continue + spark #@ +""" + ) + assert "[spark-outside-function] Using spark outside the function is leading to untestable code" in messages + + +def test_spark_inside_of_function_but_not_in_args(lint_with): + messages = ( + lint_with(SparkChecker) + << """def do_something(x): + for i in range(10): + if i > 3: + continue + spark #@ +""" + ) + assert "[no-spark-argument-in-function] Function do_something is missing a 'spark' argument" in messages