Change logic of direct filesystem access linting #2766

Merged · 22 commits · Oct 3, 2024
66 changes: 14 additions & 52 deletions src/databricks/labs/ucx/source_code/linters/directfs.py
@@ -3,7 +3,7 @@
from abc import ABC
from collections.abc import Iterable

from astroid import Attribute, Call, Const, InferenceError, JoinedStr, Name, NodeNG # type: ignore
from astroid import Call, InferenceError, NodeNG # type: ignore
from sqlglot import Expression as SqlExpression, parse as parse_sql, ParseError as SqlParseError
from sqlglot.expressions import Alter, Create, Delete, Drop, Identifier, Insert, Literal, Select

@@ -72,43 +72,36 @@ class _DetectDirectFsAccessVisitor(TreeVisitor):
def __init__(self, session_state: CurrentSessionState, prevent_spark_duplicates: bool) -> None:
self._session_state = session_state
self._directfs_nodes: list[DirectFsAccessNode] = []
self._reported_locations: set[tuple[int, int]] = set()
self._prevent_spark_duplicates = prevent_spark_duplicates

def visit_call(self, node: Call):
for arg in node.args:
self._visit_arg(arg)
self._visit_arg(node, arg)

def _visit_arg(self, arg: NodeNG):
def _visit_arg(self, call: Call, arg: NodeNG):
try:
for inferred in InferredValue.infer_from_node(arg, self._session_state):
if not inferred.is_inferred():
logger.debug(f"Could not infer value of {arg.as_string()}")
continue
self._check_str_constant(arg, inferred)
self._check_str_arg(call, arg, inferred)
except InferenceError as e:
logger.debug(f"Could not infer value of {arg.as_string()}", exc_info=e)

def visit_const(self, node: Const):
# Constant strings yield Advisories
if isinstance(node.value, str):
self._check_str_constant(node, InferredValue([node]))

def _check_str_constant(self, source_node: NodeNG, inferred: InferredValue):
if self._already_reported(source_node, inferred):
return
# don't report on JoinedStr fragments
if isinstance(source_node.parent, JoinedStr):
return
def _check_str_arg(self, call_node: Call, arg_node: NodeNG, inferred: InferredValue):
value = inferred.as_string()
for pattern in DIRECT_FS_ACCESS_PATTERNS:
if not pattern.matches(value):
continue
# avoid false positives with relative URLs
if self._is_http_call_parameter(source_node):
# don't capture calls not originating from spark or dbutils
is_from_spark = False
is_from_db_utils = Tree(call_node).is_from_module("dbutils")
if not is_from_db_utils:
is_from_spark = Tree(call_node).is_from_module("spark")
Member:

Why can is_from_spark only be True when is_from_db_utils is False? I would expect those variables to be independent. If that is the case, this dependency introduced on line 99 creates a bug on line 104.

@ericvergnaud (Contributor Author, Sep 30, 2024):

They are independent, but if either is False we never reach line 104.

Member:

Do you mean: "They are independent, but if both are False we never reach line 104"? That is what I read: the "and" in the code with negations on both sides.

@ericvergnaud:

Yup, sorry.

@ericvergnaud:

The purpose here is to minimize calls to is_from_module because they're not cheap.

Member:

> The purpose here is to minimize calls to is_from_module because they're not cheap.

Could you add that as a comment? It is fine with me, just want to be sure the logic is correct though.

@ericvergnaud:

Done.

if not is_from_db_utils and not is_from_spark:
return
Collaborator:

Suggested change — replace:

    is_from_spark = False
    is_from_db_utils = Tree(call_node).is_from_module("dbutils")
    if not is_from_db_utils:
        is_from_spark = Tree(call_node).is_from_module("spark")
    if not is_from_db_utils and not is_from_spark:
        return

with:

    tree = Tree(call_node)
    is_from_spark = tree.is_from_module("spark")
    is_from_dbutils = tree.is_from_module("dbutils")
    if not (is_from_dbutils or is_from_spark):
        return

I wonder if this logic is more readable.

@ericvergnaud (Contributor Author, Oct 2, 2024):

It is, but it requires making 2 calls to is_from_module in all cases. Changed.
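The trade-off the thread settles on can be sketched in isolation. FakeTree below is a hypothetical stand-in for ucx's Tree (whose is_from_module does real AST analysis and is the expensive call being minimized); it only counts invocations to show that the ordering short-circuits the second check:

```python
class FakeTree:
    """Hypothetical stand-in for ucx's Tree; counts is_from_module calls."""

    calls = 0  # class-level counter for the "expensive" lookups

    def __init__(self, modules: set[str]) -> None:
        self._modules = modules

    def is_from_module(self, name: str) -> bool:
        FakeTree.calls += 1  # each call stands in for a costly AST walk
        return name in self._modules


def should_report(tree: FakeTree) -> bool:
    # Mirrors the merged logic: only probe "spark" when the call is not
    # already known to come from dbutils, so at most one extra call is made.
    is_from_spark = False
    is_from_db_utils = tree.is_from_module("dbutils")
    if not is_from_db_utils:
        is_from_spark = tree.is_from_module("spark")
    return is_from_db_utils or is_from_spark


FakeTree.calls = 0
assert should_report(FakeTree({"dbutils"})) is True
assert FakeTree.calls == 1  # "spark" check was skipped

FakeTree.calls = 0
assert should_report(FakeTree(set())) is False
assert FakeTree.calls == 2  # both checks needed to rule the call out
```

The collaborator's flat version reads more uniformly but always pays for both lookups; the merged version pays for the second only on the dbutils miss.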

# avoid duplicate advices that are reported by SparkSqlPyLinter
if self._prevent_spark_duplicates and Tree(source_node).is_from_module("spark"):
if self._prevent_spark_duplicates and is_from_spark:
return
# since we're normally filtering out spark calls, we're dealing with dfsas we know little about
# notably we don't know is_read or is_write
@@ -117,39 +110,8 @@ def _check_str_constant(self, source_node: NodeNG, inferred: InferredValue):
is_read=True,
is_write=False,
)
self._directfs_nodes.append(DirectFsAccessNode(dfsa, source_node))
self._reported_locations.add((source_node.lineno, source_node.col_offset))

@classmethod
def _is_http_call_parameter(cls, source_node: NodeNG):
if not isinstance(source_node.parent, Call):
return False
# for now we only cater for ws.api_client.do
return cls._is_ws_api_client_do_call(source_node)

@classmethod
def _is_ws_api_client_do_call(cls, source_node: NodeNG):
assert isinstance(source_node.parent, Call)
func = source_node.parent.func
if not isinstance(func, Attribute) or func.attrname != "do":
return False
expr = func.expr
if not isinstance(expr, Attribute) or expr.attrname != "api_client":
return False
expr = expr.expr
if not isinstance(expr, Name):
return False
for value in InferredValue.infer_from_node(expr):
if not value.is_inferred():
continue
for node in value.nodes:
return Tree(node).is_instance_of("WorkspaceClient")
# at this point is seems safer to assume that expr.expr is a workspace than the opposite
return True

def _already_reported(self, source_node: NodeNG, inferred: InferredValue):
all_nodes = [source_node] + inferred.nodes
return any((node.lineno, node.col_offset) in self._reported_locations for node in all_nodes)
self._directfs_nodes.append(DirectFsAccessNode(dfsa, arg_node))
return

@property
def directfs_nodes(self):
15 changes: 9 additions & 6 deletions tests/unit/source_code/linters/test_directfs.py
@@ -29,11 +29,14 @@ def test_matches_dfsa_pattern(path, matches):
"code, expected",
[
('SOME_CONSTANT = "not a file system path"', 0),
('SOME_CONSTANT = ("/dbfs/mnt", "dbfs:/", "/mnt/")', 3),
('SOME_CONSTANT = ("/dbfs/mnt", "dbfs:/", "/mnt/")', 0),
('# "/dbfs/mnt"', 0),
('SOME_CONSTANT = "/dbfs/mnt"', 1),
('SOME_CONSTANT = "/dbfs/mnt"; load_data(SOME_CONSTANT)', 1),
('SOME_CONSTANT = "/dbfs/mnt"', 0),
('SOME_CONSTANT = "/dbfs/mnt"; load_data(SOME_CONSTANT)', 0),
('SOME_CONSTANT = "/dbfs/mnt"; spark.table(SOME_CONSTANT)', 1),
('SOME_CONSTANT = ("/dbfs/mnt", "dbfs:/", "/mnt/"); [dbutils.fs(path) for path in SOME_CONSTANT]', 3),
('SOME_CONSTANT = 42; load_data(SOME_CONSTANT)', 0),
('SOME_CONSTANT = "/dbfs/mnt"; dbutils.fs(SOME_CONSTANT)', 1),
],
)
def test_detects_dfsa_paths(code, expected):
@@ -47,9 +50,9 @@ def test_detects_dfsa_paths(code, expected):
@pytest.mark.parametrize(
"code, expected",
[
("load_data('/dbfs/mnt/data')", 1),
("load_data('/data')", 1),
("load_data('/dbfs/mnt/data', '/data')", 2),
("load_data('/dbfs/mnt/data')", 0),
("dbutils.fs('/data')", 1),
("dbutils.fs('/dbfs/mnt/data', '/data')", 2),
("# load_data('/dbfs/mnt/data', '/data')", 0),
('spark.read.parquet("/mnt/foo/bar")', 1),
('spark.read.parquet("dbfs:/mnt/foo/bar")', 1),
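The updated test matrix implies a small set of path prefixes the linter treats as direct filesystem access. As a rough illustration only (ucx's actual DIRECT_FS_ACCESS_PATTERNS are defined outside this diff and are richer than a prefix check), a naive matcher over the prefixes the tests exercise might look like:

```python
# Prefixes inferred from the test cases above; illustrative, not ucx's real patterns.
DFSA_PREFIXES = ("dbfs:/", "/dbfs/", "/mnt/")


def looks_like_dfsa(value: str) -> bool:
    """Naive prefix check over the DBFS-style paths seen in the tests."""
    return any(value.startswith(prefix) for prefix in DFSA_PREFIXES)
```

Note that after this PR a matching string alone is no longer enough to raise an advisory: the string must also flow into a spark or dbutils call, which is why the bare SOME_CONSTANT assignments now expect 0 findings.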
@@ -5,20 +5,12 @@
-- COMMAND ----------
-- DBTITLE 1,A Python cell that references DBFS
-- MAGIC %python
-- ucx[direct-filesystem-access:+1:7:+1:18] The use of direct filesystem references is deprecated: dbfs:/...
-- MAGIC DBFS = "dbfs:/..."
-- ucx[direct-filesystem-access:+1:7:+1:18] The use of direct filesystem references is deprecated: /dbfs/mnt
-- MAGIC DBFS = "/dbfs/mnt"
-- ucx[direct-filesystem-access:+1:7:+1:14] The use of direct filesystem references is deprecated: /mnt/
-- MAGIC DBFS = "/mnt/"
-- ucx[direct-filesystem-access:+1:7:+1:18] The use of direct filesystem references is deprecated: dbfs:/...
-- MAGIC DBFS = "dbfs:/..."
-- ucx[direct-filesystem-access:+1:10:+1:26] The use of direct filesystem references is deprecated: /dbfs/mnt/data
-- MAGIC load_data('/dbfs/mnt/data')
-- ucx[direct-filesystem-access:+1:10:+1:17] The use of direct filesystem references is deprecated: /data
-- MAGIC load_data('/data')
-- ucx[direct-filesystem-access:+2:10:+2:26] The use of direct filesystem references is deprecated: /dbfs/mnt/data
-- ucx[direct-filesystem-access:+1:28:+1:35] The use of direct filesystem references is deprecated: /data
-- MAGIC load_data('/dbfs/mnt/data', '/data')
-- MAGIC # load_data('/dbfs/mnt/data', '/data')
-- ucx[direct-filesystem-access:+1:0:+1:34] The use of direct filesystem references is deprecated: /mnt/foo/bar
@@ -1,6 +1,4 @@
# ucx[direct-filesystem-access:+1:6:+1:26] The use of direct filesystem references is deprecated: dbfs:/mnt/foo/bar1
DBFS1="dbfs:/mnt/foo/bar1"
# ucx[direct-filesystem-access:+1:16:+1:36] The use of direct filesystem references is deprecated: dbfs:/mnt/foo/bar2
systems=[DBFS1, "dbfs:/mnt/foo/bar2"]
for system in systems:
# ucx[direct-filesystem-access:+2:4:+2:30] The use of direct filesystem references is deprecated: dbfs:/mnt/foo/bar1