From 6b51cb1f2c6748781f6193df613c1f47e067f39f Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Fri, 27 Sep 2024 18:38:44 +0200 Subject: [PATCH 01/15] Change logic of direct filesystem access linting --- .../labs/ucx/source_code/linters/directfs.py | 66 ++++--------------- .../unit/source_code/linters/test_directfs.py | 15 +++-- .../file-access/complex-sql-notebook.sql | 8 --- .../functional/file-access/direct-fs3.py | 2 - 4 files changed, 23 insertions(+), 68 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/linters/directfs.py b/src/databricks/labs/ucx/source_code/linters/directfs.py index 91611c2adb..dbaa4371a1 100644 --- a/src/databricks/labs/ucx/source_code/linters/directfs.py +++ b/src/databricks/labs/ucx/source_code/linters/directfs.py @@ -3,7 +3,7 @@ from abc import ABC from collections.abc import Iterable -from astroid import Attribute, Call, Const, InferenceError, JoinedStr, Name, NodeNG # type: ignore +from astroid import Call, InferenceError, NodeNG # type: ignore from sqlglot import Expression as SqlExpression, parse as parse_sql, ParseError as SqlParseError from sqlglot.expressions import Alter, Create, Delete, Drop, Identifier, Insert, Literal, Select @@ -72,43 +72,36 @@ class _DetectDirectFsAccessVisitor(TreeVisitor): def __init__(self, session_state: CurrentSessionState, prevent_spark_duplicates: bool) -> None: self._session_state = session_state self._directfs_nodes: list[DirectFsAccessNode] = [] - self._reported_locations: set[tuple[int, int]] = set() self._prevent_spark_duplicates = prevent_spark_duplicates def visit_call(self, node: Call): for arg in node.args: - self._visit_arg(arg) + self._visit_arg(node, arg) - def _visit_arg(self, arg: NodeNG): + def _visit_arg(self, call: Call, arg: NodeNG): try: for inferred in InferredValue.infer_from_node(arg, self._session_state): if not inferred.is_inferred(): logger.debug(f"Could not infer value of {arg.as_string()}") continue - self._check_str_constant(arg, inferred) + self._check_str_arg(call, arg, inferred) except InferenceError as e: logger.debug(f"Could not infer value of {arg.as_string()}", exc_info=e) - def visit_const(self, node: Const): - # Constant strings yield Advisories - if isinstance(node.value, str): - self._check_str_constant(node, InferredValue([node])) - - def _check_str_constant(self, source_node: NodeNG, inferred: InferredValue): - if self._already_reported(source_node, inferred): - return - # don't report on JoinedStr fragments - if isinstance(source_node.parent, JoinedStr): - return + def _check_str_arg(self, call_node: Call, arg_node: NodeNG, inferred: InferredValue): value = inferred.as_string() for pattern in DIRECT_FS_ACCESS_PATTERNS: if not pattern.matches(value): continue - # avoid false positives with relative URLs - if self._is_http_call_parameter(source_node): + # don't capture calls not originating from spark or dbutils + is_from_spark = False + is_from_db_utils = Tree(call_node).is_from_module("dbutils") + if not is_from_db_utils: + is_from_spark = Tree(call_node).is_from_module("spark") + if not is_from_db_utils and not is_from_spark: return # avoid duplicate advices that are reported by SparkSqlPyLinter - if self._prevent_spark_duplicates and Tree(source_node).is_from_module("spark"): + if self._prevent_spark_duplicates and is_from_spark: return # since we're normally filtering out spark calls, we're dealing with dfsas we know little about # notably we don't know is_read or is_write @@ -117,39 +110,8 @@ def _check_str_constant(self, source_node: NodeNG, inferred: InferredValue): is_read=True, is_write=False, ) - self._directfs_nodes.append(DirectFsAccessNode(dfsa, source_node)) - self._reported_locations.add((source_node.lineno, source_node.col_offset)) - - @classmethod - def _is_http_call_parameter(cls, source_node: NodeNG): - if not isinstance(source_node.parent, Call): - return False - # for now we only cater for ws.api_client.do - return cls._is_ws_api_client_do_call(source_node) - - @classmethod - def _is_ws_api_client_do_call(cls, source_node: NodeNG): - assert isinstance(source_node.parent, Call) - func = source_node.parent.func - if not isinstance(func, Attribute) or func.attrname != "do": - return False - expr = func.expr - if not isinstance(expr, Attribute) or expr.attrname != "api_client": - return False - expr = expr.expr - if not isinstance(expr, Name): - return False - for value in InferredValue.infer_from_node(expr): - if not value.is_inferred(): - continue - for node in value.nodes: - return Tree(node).is_instance_of("WorkspaceClient") - # at this point is seems safer to assume that expr.expr is a workspace than the opposite - return True - - def _already_reported(self, source_node: NodeNG, inferred: InferredValue): - all_nodes = [source_node] + inferred.nodes - return any((node.lineno, node.col_offset) in self._reported_locations for node in all_nodes) + self._directfs_nodes.append(DirectFsAccessNode(dfsa, arg_node)) + return @property def directfs_nodes(self): diff --git a/tests/unit/source_code/linters/test_directfs.py b/tests/unit/source_code/linters/test_directfs.py index 2e482d8bb0..ba50e177b6 100644 --- a/tests/unit/source_code/linters/test_directfs.py +++ b/tests/unit/source_code/linters/test_directfs.py @@ -29,11 +29,14 @@ def test_matches_dfsa_pattern(path, matches): "code, expected", [ ('SOME_CONSTANT = "not a file system path"', 0), - ('SOME_CONSTANT = ("/dbfs/mnt", "dbfs:/", "/mnt/")', 3), + ('SOME_CONSTANT = ("/dbfs/mnt", "dbfs:/", "/mnt/")', 0), ('# "/dbfs/mnt"', 0), - ('SOME_CONSTANT = "/dbfs/mnt"', 1), - ('SOME_CONSTANT = "/dbfs/mnt"; load_data(SOME_CONSTANT)', 1), + ('SOME_CONSTANT = "/dbfs/mnt"', 0), + ('SOME_CONSTANT = "/dbfs/mnt"; load_data(SOME_CONSTANT)', 0), + ('SOME_CONSTANT = "/dbfs/mnt"; spark.table(SOME_CONSTANT)', 1), + ('SOME_CONSTANT = ("/dbfs/mnt", "dbfs:/", "/mnt/"); [dbutils.fs(path) for path in SOME_CONSTANT]', 3), ('SOME_CONSTANT = 42; load_data(SOME_CONSTANT)', 0), + ('SOME_CONSTANT = "/dbfs/mnt"; dbutils.fs(SOME_CONSTANT)', 1), ], ) def test_detects_dfsa_paths(code, expected): @@ -47,9 +50,9 @@ def test_detects_dfsa_paths(code, expected): @pytest.mark.parametrize( "code, expected", [ - ("load_data('/dbfs/mnt/data')", 1), - ("load_data('/data')", 1), - ("load_data('/dbfs/mnt/data', '/data')", 2), + ("load_data('/dbfs/mnt/data')", 0), + ("dbutils.fs('/data')", 1), + ("dbutils.fs('/dbfs/mnt/data', '/data')", 2), ("# load_data('/dbfs/mnt/data', '/data')", 0), ('spark.read.parquet("/mnt/foo/bar")', 1), ('spark.read.parquet("dbfs:/mnt/foo/bar")', 1), diff --git a/tests/unit/source_code/samples/functional/file-access/complex-sql-notebook.sql b/tests/unit/source_code/samples/functional/file-access/complex-sql-notebook.sql index f77e09f4c6..f3f542360b 100644 --- a/tests/unit/source_code/samples/functional/file-access/complex-sql-notebook.sql +++ b/tests/unit/source_code/samples/functional/file-access/complex-sql-notebook.sql @@ -5,20 +5,12 @@ -- COMMAND ---------- -- DBTITLE 1,A Python cell that references DBFS -- MAGIC %python --- ucx[direct-filesystem-access:+1:7:+1:18] The use of direct filesystem references is deprecated: dbfs:/... -- MAGIC DBFS = "dbfs:/..." --- ucx[direct-filesystem-access:+1:7:+1:18] The use of direct filesystem references is deprecated: /dbfs/mnt -- MAGIC DBFS = "/dbfs/mnt" --- ucx[direct-filesystem-access:+1:7:+1:14] The use of direct filesystem references is deprecated: /mnt/ -- MAGIC DBFS = "/mnt/" --- ucx[direct-filesystem-access:+1:7:+1:18] The use of direct filesystem references is deprecated: dbfs:/... -- MAGIC DBFS = "dbfs:/..." --- ucx[direct-filesystem-access:+1:10:+1:26] The use of direct filesystem references is deprecated: /dbfs/mnt/data -- MAGIC load_data('/dbfs/mnt/data') --- ucx[direct-filesystem-access:+1:10:+1:17] The use of direct filesystem references is deprecated: /data -- MAGIC load_data('/data') --- ucx[direct-filesystem-access:+2:10:+2:26] The use of direct filesystem references is deprecated: /dbfs/mnt/data --- ucx[direct-filesystem-access:+1:28:+1:35] The use of direct filesystem references is deprecated: /data -- MAGIC load_data('/dbfs/mnt/data', '/data') -- MAGIC # load_data('/dbfs/mnt/data', '/data') -- ucx[direct-filesystem-access:+1:0:+1:34] The use of direct filesystem references is deprecated: /mnt/foo/bar diff --git a/tests/unit/source_code/samples/functional/file-access/direct-fs3.py b/tests/unit/source_code/samples/functional/file-access/direct-fs3.py index 0db9d9a2f1..c9a59ccd49 100644 --- a/tests/unit/source_code/samples/functional/file-access/direct-fs3.py +++ b/tests/unit/source_code/samples/functional/file-access/direct-fs3.py @@ -1,6 +1,4 @@ -# ucx[direct-filesystem-access:+1:6:+1:26] The use of direct filesystem references is deprecated: dbfs:/mnt/foo/bar1 DBFS1="dbfs:/mnt/foo/bar1" -# ucx[direct-filesystem-access:+1:16:+1:36] The use of direct filesystem references is deprecated: dbfs:/mnt/foo/bar2 systems=[DBFS1, "dbfs:/mnt/foo/bar2"] for system in systems: # ucx[direct-filesystem-access:+2:4:+2:30] The use of direct filesystem references is deprecated: dbfs:/mnt/foo/bar1 From fe0f9425eae3cf75f5a7d43336eeb58efffce3a5 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Tue, 1 Oct 2024 11:20:47 +0200 Subject: [PATCH 02/15] more comments --- src/databricks/labs/ucx/source_code/linters/directfs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/databricks/labs/ucx/source_code/linters/directfs.py b/src/databricks/labs/ucx/source_code/linters/directfs.py index dbaa4371a1..0023ffee23 100644 --- a/src/databricks/labs/ucx/source_code/linters/directfs.py +++ b/src/databricks/labs/ucx/source_code/linters/directfs.py @@ -93,7 +93,8 @@ def _check_str_arg(self, call_node: Call, arg_node: NodeNG, inferred: InferredVa for pattern in DIRECT_FS_ACCESS_PATTERNS: if not pattern.matches(value): continue - # don't capture calls not originating from spark or dbutils + # only capture calls originating from spark or dbutils + # because there is no other known way to manipulate data directly from file system is_from_spark = False is_from_db_utils = Tree(call_node).is_from_module("dbutils") if not is_from_db_utils: From 51afb4ddfe49e9c8ce5e5774e658e0fa1876e03f Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 2 Oct 2024 11:17:52 +0200 Subject: [PATCH 03/15] improve readability --- src/databricks/labs/ucx/source_code/linters/directfs.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/linters/directfs.py b/src/databricks/labs/ucx/source_code/linters/directfs.py index 0023ffee23..0bb07dd1c4 100644 --- a/src/databricks/labs/ucx/source_code/linters/directfs.py +++ b/src/databricks/labs/ucx/source_code/linters/directfs.py @@ -95,11 +95,9 @@ def _check_str_arg(self, call_node: Call, arg_node: NodeNG, inferred: InferredVa continue # only capture calls originating from spark or dbutils # because there is no other known way to manipulate data directly from file system - is_from_spark = False is_from_db_utils = Tree(call_node).is_from_module("dbutils") - if not is_from_db_utils: - is_from_spark = Tree(call_node).is_from_module("spark") - if not is_from_db_utils and not is_from_spark: + is_from_spark = False if is_from_db_utils else Tree(call_node).is_from_module("spark") + if not (is_from_db_utils or is_from_spark): return # avoid duplicate advices that are reported by SparkSqlPyLinter if self._prevent_spark_duplicates and is_from_spark: From d6b3d6ee47f297b641e533f497e81ca5ecf6df76 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 2 Oct 2024 11:21:16 +0200 Subject: [PATCH 04/15] improve readability --- src/databricks/labs/ucx/source_code/linters/directfs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/linters/directfs.py b/src/databricks/labs/ucx/source_code/linters/directfs.py index 0bb07dd1c4..bfbfc36181 100644 --- a/src/databricks/labs/ucx/source_code/linters/directfs.py +++ b/src/databricks/labs/ucx/source_code/linters/directfs.py @@ -95,8 +95,9 @@ def _check_str_arg(self, call_node: Call, arg_node: NodeNG, inferred: InferredVa continue # only capture calls originating from spark or dbutils # because there is no other known way to manipulate data directly from file system - is_from_db_utils = Tree(call_node).is_from_module("dbutils") - is_from_spark = False if is_from_db_utils else Tree(call_node).is_from_module("spark") + tree = Tree(call_node) + is_from_db_utils = tree.is_from_module("dbutils") + is_from_spark = False if is_from_db_utils else tree.is_from_module("spark") if not (is_from_db_utils or is_from_spark): return # avoid duplicate advices that are reported by SparkSqlPyLinter From 6bc6f7e36d642606759f3f994166c6251b83dc0c Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 2 Oct 2024 12:47:44 +0200 Subject: [PATCH 05/15] add utility methods --- .../labs/ucx/source_code/python/python_ast.py | 28 +++++++++++++++++-- .../source_code/python/test_python_ast.py | 23 +++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/python/python_ast.py b/src/databricks/labs/ucx/source_code/python/python_ast.py index 5bf6781374..ba0783eebb 100644 --- a/src/databricks/labs/ucx/source_code/python/python_ast.py +++ b/src/databricks/labs/ucx/source_code/python/python_ast.py @@ -4,7 +4,7 @@ import logging import re from collections.abc import Iterable -from typing import TypeVar, cast +from typing import TypeVar, cast, Any from astroid import ( # type: ignore Assign, @@ -23,9 +23,9 @@ parse, Uninferable, ) +from astroid.modutils import BUILTIN_MODULES # type: ignore logger = logging.getLogger(__name__) - missing_handlers: set[str] = set() @@ -289,6 +289,30 @@ def renumber_node(node: NodeNG, offset: int) -> None: start = start + num_lines if start > 0 else start - num_lines return self + def get_call_name(self) -> str: + if not isinstance(self._node, Call): + return "" + func = self._node.func + if isinstance(func, Name): + return func.name + elif isinstance(func, Attribute): + return func.attrname + else: + return "" # not supported yet + + def is_builtin(self) -> bool: + if isinstance(self._node, Name): + name = self._node.name + builtins = cast(dict[str, Any], __builtins__) + if name in builtins.keys(): + return True + astroid_name = f"_{name}" + return BUILTIN_MODULES.get(astroid_name, None) is not None + if isinstance(self._node, Call): + return Tree(self._node.func).is_builtin() + if isinstance(self._node, Attribute): + return Tree(self._node.expr).is_builtin() + return False # not supported yet class _LocalTree(Tree): diff --git a/tests/unit/source_code/python/test_python_ast.py b/tests/unit/source_code/python/test_python_ast.py index 7d9f27e73e..460944bb43 100644 --- a/tests/unit/source_code/python/test_python_ast.py +++ b/tests/unit/source_code/python/test_python_ast.py @@ -233,3 +233,26 @@ def test_renumbers_negatively(): def test_counts_lines(source: str, line_count: int): tree = Tree.normalize_and_parse(source) assert tree.line_count() == line_count + + +@pytest.mark.parametrize( + "source, name, is_builtin", [ + ("x = open()", "open", True), + ("import datetime; x = datetime.datetime.now()", "now", True), + ("import stuff; x = stuff()", "stuff", False), + ("""def stuff(): + pass +x = stuff()""", "stuff", False), + ]) +def test_is_builtin(source, name, is_builtin): + tree = Tree.normalize_and_parse(source) + nodes = list(tree.node.get_children()) + for node in nodes: + if isinstance(node, Assign): + tree = Tree(node.value) + func_name = tree.get_call_name() + assert func_name == name + assert tree.is_builtin() == is_builtin + return + assert False # could not locate call + From 843b2a0d05214cf0d5f3c91f68cc2e10bdf264d0 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 2 Oct 2024 12:49:54 +0200 Subject: [PATCH 06/15] detect direct filesystem access in calls to builtin 'open' --- src/databricks/labs/ucx/source_code/linters/directfs.py | 9 +++++---- tests/unit/source_code/linters/test_directfs.py | 2 ++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/linters/directfs.py b/src/databricks/labs/ucx/source_code/linters/directfs.py index bfbfc36181..bd77f4c34b 100644 --- a/src/databricks/labs/ucx/source_code/linters/directfs.py +++ b/src/databricks/labs/ucx/source_code/linters/directfs.py @@ -93,12 +93,13 @@ def _check_str_arg(self, call_node: Call, arg_node: NodeNG, inferred: InferredVa for pattern in DIRECT_FS_ACCESS_PATTERNS: if not pattern.matches(value): continue - # only capture calls originating from spark or dbutils + # only capture 'open' calls or calls originating from spark or dbutils # because there is no other known way to manipulate data directly from file system tree = Tree(call_node) - is_from_db_utils = tree.is_from_module("dbutils") - is_from_spark = False if is_from_db_utils else tree.is_from_module("spark") - if not (is_from_db_utils or is_from_spark): + is_open = tree.get_call_name() == "open" and tree.is_builtin() + is_from_db_utils = False if is_open else tree.is_from_module("dbutils") + is_from_spark = False if is_open or is_from_db_utils else tree.is_from_module("spark") + if not (is_open or is_from_db_utils or is_from_spark): return # avoid duplicate advices that are reported by SparkSqlPyLinter if self._prevent_spark_duplicates and is_from_spark: diff --git a/tests/unit/source_code/linters/test_directfs.py b/tests/unit/source_code/linters/test_directfs.py index ba50e177b6..7fed86d23a 100644 --- a/tests/unit/source_code/linters/test_directfs.py +++ b/tests/unit/source_code/linters/test_directfs.py @@ -51,6 +51,8 @@ def test_detects_dfsa_paths(code, expected): "code, expected", [ ("load_data('/dbfs/mnt/data')", 0), + ("""with open('/dbfs/mnt/data') as f: + f.read()""", 1), ("dbutils.fs('/data')", 1), ("dbutils.fs('/dbfs/mnt/data', '/data')", 2), ("# load_data('/dbfs/mnt/data', '/data')", 0), From 97d62fae34fbe15a514e7368eb007ca4a43b1044 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 2 Oct 2024 13:02:39 +0200 Subject: [PATCH 07/15] formatting --- .../labs/ucx/source_code/linters/directfs.py | 4 ++-- .../labs/ucx/source_code/python/python_ast.py | 23 ++++++++++--------- .../unit/source_code/linters/test_directfs.py | 7 ++++-- .../source_code/python/test_python_ast.py | 22 +++++++++++------- 4 files changed, 33 insertions(+), 23 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/linters/directfs.py b/src/databricks/labs/ucx/source_code/linters/directfs.py index bd77f4c34b..6ff12bde13 100644 --- a/src/databricks/labs/ucx/source_code/linters/directfs.py +++ b/src/databricks/labs/ucx/source_code/linters/directfs.py @@ -15,7 +15,7 @@ SqlLinter, ) from databricks.labs.ucx.source_code.directfs_access import DirectFsAccess -from databricks.labs.ucx.source_code.python.python_ast import Tree, TreeVisitor +from databricks.labs.ucx.source_code.python.python_ast import Tree, TreeVisitor, TreeHelper from databricks.labs.ucx.source_code.python.python_infer import InferredValue logger = logging.getLogger(__name__) @@ -96,7 +96,7 @@ def _check_str_arg(self, call_node: Call, arg_node: NodeNG, inferred: InferredVa # only capture 'open' calls or calls originating from spark or dbutils # because there is no other known way to manipulate data directly from file system tree = Tree(call_node) - is_open = tree.get_call_name() == "open" and tree.is_builtin() + is_open = TreeHelper.get_call_name(call_node) == "open" and tree.is_builtin() is_from_db_utils = False if is_open else tree.is_from_module("dbutils") is_from_spark = False if is_open or is_from_db_utils else tree.is_from_module("spark") if not (is_open or is_from_db_utils or is_from_spark): diff --git a/src/databricks/labs/ucx/source_code/python/python_ast.py b/src/databricks/labs/ucx/source_code/python/python_ast.py index ba0783eebb..c2bc95437e 100644 --- a/src/databricks/labs/ucx/source_code/python/python_ast.py +++ b/src/databricks/labs/ucx/source_code/python/python_ast.py @@ -289,17 +289,6 @@ def renumber_node(node: NodeNG, offset: int) -> None: start = start + num_lines if start > 0 else start - num_lines return self - def get_call_name(self) -> str: - if not isinstance(self._node, Call): - return "" - func = self._node.func - if isinstance(func, Name): - return func.name - elif isinstance(func, Attribute): - return func.attrname - else: - return "" # not supported yet - def is_builtin(self) -> bool: if isinstance(self._node, Name): name = self._node.name @@ -314,6 +303,7 @@ def is_builtin(self) -> bool: return Tree(self._node.expr).is_builtin() return False # not supported yet + class _LocalTree(Tree): def is_from_module_visited(self, name: str, visited_nodes: set[NodeNG]) -> bool: @@ -322,6 +312,17 @@ def is_from_module_visited(self, name: str, visited_nodes: set[NodeNG]) -> bool: class TreeHelper(ABC): + @classmethod + def get_call_name(cls, call: Call) -> str: + if not isinstance(call, Call): + return "" + func = call.func + if isinstance(func, Name): + return func.name + if isinstance(func, Attribute): + return func.attrname + return "" # not supported yet + @classmethod def extract_call_by_name(cls, call: Call, name: str) -> Call | None: """Given a call-chain, extract its sub-call by method name (if it has one)""" diff --git a/tests/unit/source_code/linters/test_directfs.py b/tests/unit/source_code/linters/test_directfs.py index 7fed86d23a..fdb8ddf8db 100644 --- a/tests/unit/source_code/linters/test_directfs.py +++ b/tests/unit/source_code/linters/test_directfs.py @@ -51,8 +51,11 @@ def test_detects_dfsa_paths(code, expected): "code, expected", [ ("load_data('/dbfs/mnt/data')", 0), - ("""with open('/dbfs/mnt/data') as f: - f.read()""", 1), + ( + """with open('/dbfs/mnt/data') as f: + f.read()""", + 1, + ), ("dbutils.fs('/data')", 1), ("dbutils.fs('/dbfs/mnt/data', '/data')", 2), ("# load_data('/dbfs/mnt/data', '/data')", 0), diff --git a/tests/unit/source_code/python/test_python_ast.py b/tests/unit/source_code/python/test_python_ast.py index 460944bb43..d6e6ae54c6 100644 --- a/tests/unit/source_code/python/test_python_ast.py +++ b/tests/unit/source_code/python/test_python_ast.py @@ -236,23 +236,29 @@ def test_counts_lines(source: str, line_count: int): @pytest.mark.parametrize( - "source, name, is_builtin", [ + "source, name, is_builtin", + [ ("x = open()", "open", True), ("import datetime; x = datetime.datetime.now()", "now", True), ("import stuff; x = stuff()", "stuff", False), - ("""def stuff(): + ( + """def stuff(): pass -x = stuff()""", "stuff", False), - ]) +x = stuff()""", + "stuff", + False, + ), + ], +) def test_is_builtin(source, name, is_builtin): tree = Tree.normalize_and_parse(source) nodes = list(tree.node.get_children()) for node in nodes: if isinstance(node, Assign): - tree = Tree(node.value) - func_name = tree.get_call_name() + call = node.value + assert isinstance(call, Call) + func_name = TreeHelper.get_call_name(call) assert func_name == name - assert tree.is_builtin() == is_builtin + assert Tree(call).is_builtin() == is_builtin return assert False # could not locate call - From 9a3beb6b4fa9bfd75b582a8de2deec057d035d4f Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 2 Oct 2024 15:31:32 +0200 Subject: [PATCH 08/15] fix failing test --- src/databricks/labs/ucx/source_code/python/python_ast.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/python/python_ast.py b/src/databricks/labs/ucx/source_code/python/python_ast.py index c2bc95437e..0a8e739f37 100644 --- a/src/databricks/labs/ucx/source_code/python/python_ast.py +++ b/src/databricks/labs/ucx/source_code/python/python_ast.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from abc import ABC import logging import re @@ -23,7 +24,6 @@ parse, Uninferable, ) -from astroid.modutils import BUILTIN_MODULES # type: ignore logger = logging.getLogger(__name__) missing_handlers: set[str] = set() @@ -295,8 +295,8 @@ def is_builtin(self) -> bool: builtins = cast(dict[str, Any], __builtins__) if name in builtins.keys(): return True - astroid_name = f"_{name}" - return BUILTIN_MODULES.get(astroid_name, None) is not None + names = sys.builtin_module_names + return f"_{name}" in names if isinstance(self._node, Call): return Tree(self._node.func).is_builtin() if isinstance(self._node, Attribute): From d1f5cde2e55f245a24f72cf369ea5a73c955bfa0 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 2 Oct 2024 15:52:44 +0200 Subject: [PATCH 09/15] fix failing test --- src/databricks/labs/ucx/source_code/python/python_ast.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/python/python_ast.py b/src/databricks/labs/ucx/source_code/python/python_ast.py index 0a8e739f37..09f9bdd9e2 100644 --- a/src/databricks/labs/ucx/source_code/python/python_ast.py +++ b/src/databricks/labs/ucx/source_code/python/python_ast.py @@ -1,11 +1,12 @@ from __future__ import annotations +import builtins import sys from abc import ABC import logging import re from collections.abc import Iterable -from typing import TypeVar, cast, Any +from typing import TypeVar, cast from astroid import ( # type: ignore Assign, @@ -292,11 +293,7 @@ def renumber_node(node: NodeNG, offset: int) -> None: def is_builtin(self) -> bool: if isinstance(self._node, Name): name = self._node.name - builtins = cast(dict[str, Any], __builtins__) - if name in builtins.keys(): - return True - names = sys.builtin_module_names - return f"_{name}" in names + return name in dir(builtins) or name in sys.stdlib_module_names or name in sys.builtin_module_names if isinstance(self._node, Call): return Tree(self._node.func).is_builtin() if isinstance(self._node, Attribute): From 7e509369db3955c30edf5519963dcfe087a9db63 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 3 Oct 2024 15:22:00 +0200 Subject: [PATCH 10/15] fix-merge-issue --- .../labs/ucx/source_code/linters/pyspark.py | 53 +++++++++++++++---- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/linters/pyspark.py b/src/databricks/labs/ucx/source_code/linters/pyspark.py index f79dabdcea..689e07488b 100644 --- a/src/databricks/labs/ucx/source_code/linters/pyspark.py +++ b/src/databricks/labs/ucx/source_code/linters/pyspark.py @@ -52,12 +52,15 @@ def matches(self, node: NodeNG): @abstractmethod def lint( self, from_table: FromTableSqlLinter, index: TableMigrationIndex, session_state: CurrentSessionState, node: Call - ) -> Iterator[Advice]: - """raises Advices by linting the code""" + ) -> Iterable[Advice]: ... @abstractmethod - def apply(self, from_table: FromTableSqlLinter, index: TableMigrationIndex, node: Call) -> None: - """applies recommendations""" + def apply(self, from_table: FromTableSqlLinter, index: TableMigrationIndex, node: Call) -> None: ... + + @abstractmethod + def collect_tables( + self, from_table: FromTableSqlLinter, index: TableMigrationIndex, session_state: CurrentSessionState, node: Call + ) -> Iterable[UsedTable]: ... def _get_table_arg(self, node: Call): node_argc = len(node.args) @@ -91,28 +94,46 @@ def _check_call_context(self, node: Call) -> bool: @dataclass class SparkCallMatcher(_TableNameMatcher): - def lint( + def collect_tables( self, from_table: FromTableSqlLinter, index: TableMigrationIndex, session_state: CurrentSessionState, node: Call - ) -> Iterator[Advice]: + ) -> Iterable[UsedTable]: + for used_table in self._collect_tables(from_table, index, session_state, node): + if not used_table: + continue + yield used_table[1] + + def _collect_tables( + self, from_table: FromTableSqlLinter, index: TableMigrationIndex, session_state: CurrentSessionState, + node: Call + ) -> Iterable[tuple[str, UsedTable] | None]: table_arg = self._get_table_arg(node) if table_arg is None: return - table_name = table_arg.as_string().strip("'").strip('"') for inferred in InferredValue.infer_from_node(table_arg, session_state): if not inferred.is_inferred(): + yield None + continue + table_name = inferred.as_string().strip("'").strip('"') + info = UsedTable.parse(table_name, from_table.schema) + yield table_name, info + + def lint( + self, from_table: FromTableSqlLinter, index: TableMigrationIndex, session_state: CurrentSessionState, node: Call + ) -> Iterable[Advice]: + for used_table in self._collect_tables(from_table, index, session_state): + if not used_table: yield Advisory.from_node( code='cannot-autofix-table-reference', message=f"Can't migrate '{node.as_string()}' because its table name argument cannot be computed", node=node, ) continue - info = UsedTable.parse(inferred.as_string(), from_table.schema) - dst = self._find_dest(index, info) + dst = self._find_dest(index, used_table[1]) if dst is None: continue yield Deprecation.from_node( code='table-migrated-to-uc', - message=f"Table {table_name} is migrated to {dst.destination()} in Unity Catalog", + message=f"Table {used_table[0]} is migrated to {dst.destination()} in Unity Catalog", # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159 node=node, ) @@ -155,6 +176,11 @@ def apply(self, from_table: FromTableSqlLinter, index: TableMigrationIndex, node # No transformations to apply return + def collect_tables( + self, from_table: FromTableSqlLinter, index: TableMigrationIndex, session_state: CurrentSessionState, node: Call + ) -> Iterable[UsedTable]: + return [] + T = TypeVar("T") @@ -194,6 +220,11 @@ def apply(self, from_table: FromTableSqlLinter, index: TableMigrationIndex, node # No transformations to apply return + def collect_tables( + self, from_table: FromTableSqlLinter, index: TableMigrationIndex, session_state: CurrentSessionState, node: Call + ) -> Iterable[UsedTable]: + return [] # we don't collect tables through this matcher + class SparkTableNameMatchers: @@ -364,7 +395,7 @@ def collect_tables_from_tree(self, tree: Tree) -> Iterable[TableInfoNode]: if matcher is None: continue assert isinstance(node, Call) - yield from matcher.lint(self._from_table, self._index, self._session_state, node) + yield from matcher.collect_tables(self._from_table, self._index, self._session_state, node) class _SparkSqlAnalyzer: From 573bf4fcbbd2940d956fefa10f034b3786818a4f Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 3 Oct 2024 15:22:15 +0200 Subject: [PATCH 11/15] fix missing property --- src/databricks/labs/ucx/contexts/application.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/databricks/labs/ucx/contexts/application.py b/src/databricks/labs/ucx/contexts/application.py index 95944a3d2a..cef7d05642 100644 --- a/src/databricks/labs/ucx/contexts/application.py +++ b/src/databricks/labs/ucx/contexts/application.py @@ -55,6 +55,7 @@ from databricks.labs.ucx.source_code.known import KnownList from databricks.labs.ucx.source_code.queries import QueryLinter from databricks.labs.ucx.source_code.redash import Redash +from databricks.labs.ucx.source_code.used_table import UsedTablesCrawler from databricks.labs.ucx.workspace_access import generic, redash from databricks.labs.ucx.workspace_access.groups import GroupManager from databricks.labs.ucx.workspace_access.manager import PermissionManager @@ -429,6 +430,7 @@ def workflow_linter(self): self.path_lookup, TableMigrationIndex([]), # TODO: bring back self.tables_migrator.index() self.directfs_access_crawler_for_paths, + self.used_tables_crawler_for_paths, self.config.include_job_ids, ) @@ -449,6 +451,14 @@ def directfs_access_crawler_for_paths(self): def directfs_access_crawler_for_queries(self): return DirectFsAccessCrawler.for_queries(self.sql_backend, self.inventory_database) + @cached_property + def used_tables_crawler_for_paths(self): + return UsedTablesCrawler.for_paths(self.sql_backend, self.inventory_database) + + @cached_property + def used_tables_crawler_for_queries(self): + return UsedTablesCrawler.for_queries(self.sql_backend, self.inventory_database) + @cached_property def redash(self): return Redash( From ca3c6ce6e9a1e7402b483f6999bedf112dd2bc88 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 3 Oct 2024 15:25:03 +0200 Subject: [PATCH 12/15] formatting --- src/databricks/labs/ucx/source_code/linters/pyspark.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/linters/pyspark.py b/src/databricks/labs/ucx/source_code/linters/pyspark.py index 689e07488b..943fc45810 100644 --- a/src/databricks/labs/ucx/source_code/linters/pyspark.py +++ b/src/databricks/labs/ucx/source_code/linters/pyspark.py @@ -103,9 +103,8 @@ def collect_tables( yield used_table[1] def _collect_tables( - self, from_table: FromTableSqlLinter, index: TableMigrationIndex, session_state: CurrentSessionState, - node: Call - ) -> Iterable[tuple[str, UsedTable] | None]: + self, from_table: FromTableSqlLinter, index: TableMigrationIndex, session_state: CurrentSessionState, node: Call + ) -> Iterable[tuple[str, UsedTable] | None]: table_arg = self._get_table_arg(node) if table_arg is None: return From 74cb5f990cb8bbbcf4ef9e7610432ff0c844b2dc Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 3 Oct 2024 15:27:21 +0200 Subject: [PATCH 13/15] fix missing arg --- src/databricks/labs/ucx/source_code/linters/pyspark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/labs/ucx/source_code/linters/pyspark.py b/src/databricks/labs/ucx/source_code/linters/pyspark.py index 943fc45810..709a344b97 100644 --- a/src/databricks/labs/ucx/source_code/linters/pyspark.py +++ b/src/databricks/labs/ucx/source_code/linters/pyspark.py @@ -119,7 +119,7 @@ def _collect_tables( def lint( self, from_table: FromTableSqlLinter, index: TableMigrationIndex, session_state: CurrentSessionState, node: Call ) -> Iterable[Advice]: - for used_table in self._collect_tables(from_table, index, session_state): + for used_table in self._collect_tables(from_table, index, session_state, node): if not used_table: yield Advisory.from_node( code='cannot-autofix-table-reference', From 4ccbae64c8b86e56962fcf17bf4695052650f7d0 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 3 Oct 2024 15:31:33 +0200 Subject: [PATCH 14/15] formatting --- src/databricks/labs/ucx/source_code/linters/pyspark.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/linters/pyspark.py b/src/databricks/labs/ucx/source_code/linters/pyspark.py index 709a344b97..f82b2faac9 100644 --- a/src/databricks/labs/ucx/source_code/linters/pyspark.py +++ b/src/databricks/labs/ucx/source_code/linters/pyspark.py @@ -97,13 +97,13 @@ class SparkCallMatcher(_TableNameMatcher): def collect_tables( self, from_table: FromTableSqlLinter, index: TableMigrationIndex, session_state: CurrentSessionState, node: Call ) -> Iterable[UsedTable]: - for used_table in self._collect_tables(from_table, index, session_state, node): + for used_table in self._collect_tables(from_table, session_state, node): if not used_table: continue yield used_table[1] def _collect_tables( - self, from_table: FromTableSqlLinter, index: TableMigrationIndex, session_state: CurrentSessionState, node: Call + self, from_table: FromTableSqlLinter, session_state: CurrentSessionState, node: Call ) -> Iterable[tuple[str, UsedTable] | None]: table_arg = self._get_table_arg(node) if table_arg is None: @@ -119,7 +119,7 @@ def _collect_tables( def lint( self, from_table: FromTableSqlLinter, index: TableMigrationIndex, session_state: CurrentSessionState, node: Call ) -> Iterable[Advice]: - for used_table in self._collect_tables(from_table, index, session_state, node): + for used_table in self._collect_tables(from_table, session_state, node): if not used_table: yield Advisory.from_node( code='cannot-autofix-table-reference', From 56b2ce0c1ed1da61691d74471c90856e5ab52b29 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 3 Oct 2024 16:36:42 +0200 Subject: [PATCH 15/15] more tests --- .../source_code/python/test_python_ast.py | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/tests/unit/source_code/python/test_python_ast.py b/tests/unit/source_code/python/test_python_ast.py index 7d9f27e73e..f5c553c823 100644 --- a/tests/unit/source_code/python/test_python_ast.py +++ b/tests/unit/source_code/python/test_python_ast.py @@ -233,3 +233,84 @@ def test_renumbers_negatively(): def test_counts_lines(source: str, line_count: int): tree = Tree.normalize_and_parse(source) assert tree.line_count() == line_count + + +@pytest.mark.parametrize( + "source, name, is_builtin", + [ + ("x = open()", "open", True), + ("import datetime; x = datetime.datetime.now()", "now", True), + ("import stuff; x = stuff()", "stuff", False), + ( + """def stuff(): + pass +x = stuff()""", + "stuff", + False, + ), + ], +) +def test_is_builtin(source, name, is_builtin): + tree = Tree.normalize_and_parse(source) + nodes = list(tree.node.get_children()) + for node in nodes: + if isinstance(node, Assign): + call = node.value + assert isinstance(call, Call) + func_name = TreeHelper.get_call_name(call) + assert func_name == name + assert Tree(call).is_builtin() == is_builtin + return + assert False # could not locate call + + +def test_first_statement_is_none(): + node = Const("xyz") + assert not Tree(node).first_statement() + + +def test_repr_is_truncated(): + assert len(repr(Tree(Const("xyz")))) <= (32 + len("...") + len("")) + + +def test_append_tree_fails(): + with pytest.raises(NotImplementedError): + Tree(Const("xyz")).append_tree(Tree(Const("xyz"))) + + +def test_append_node_fails(): + with pytest.raises(NotImplementedError): + Tree(Const("xyz")).append_nodes([]) + + +def test_nodes_between_fails(): + with pytest.raises(NotImplementedError): + Tree(Const("xyz")).nodes_between(0, 100) + + +def test_has_global_fails(): + assert not Tree.new_module().has_global("xyz") + + +def test_append_globals_fails(): + with pytest.raises(NotImplementedError): + Tree(Const("xyz")).append_globals({}) + + +def test_globals_between_fails(): + with pytest.raises(NotImplementedError): + Tree(Const("xyz")).line_count() + + +def test_line_count_fails(): + with pytest.raises(NotImplementedError): + Tree(Const("xyz")).globals_between(0, 100) + + +def test_renumber_fails(): + with pytest.raises(NotImplementedError): + Tree(Const("xyz")).renumber(100) + + +def test_const_is_not_from_module(): + assert not Tree(Const("xyz")).is_from_module("spark")