Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added more value inference for dbutils.notebook.run(...) #1860

Merged
merged 19 commits into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions src/databricks/labs/ucx/source_code/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,19 +183,31 @@ def build_graph_from_python_source(self, python_code: str) -> list[DependencyPro
def _process_node(self, base_node: NodeBase):
if isinstance(base_node, SysPathChange):
self._mutate_path_lookup(base_node)
return
if isinstance(base_node, NotebookRunCall):
strpath = base_node.get_notebook_path()
if strpath is None:
yield DependencyProblem('dependency-not-constant', "Can't check dependency not provided as a constant")
else:
yield from self.register_notebook(Path(strpath))
yield from self._register_notebook(base_node)
return
if isinstance(base_node, ImportSource):
prefix = ""
if isinstance(base_node.node, ImportFrom) and base_node.node.level is not None:
prefix = "." * base_node.node.level
name = base_node.name or ""
yield from self.register_import(prefix + name)

def _register_notebook(self, base_node: NotebookRunCall):
paths = base_node.get_notebook_paths()
asserted = False
for path in paths:
if isinstance(path, str):
yield from self.register_notebook(Path(path))
continue
if not asserted:
asserted = True
yield DependencyProblem(
'dependency-cannot-compute',
f"Can't check dependency from {base_node.node.as_string()} because the expression cannot be computed",
)

def _mutate_path_lookup(self, change: SysPathChange):
path = Path(change.path)
if not path.is_absolute():
Expand Down
37 changes: 25 additions & 12 deletions src/databricks/labs/ucx/source_code/linters/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Attribute,
Call,
Const,
InferenceError,
Import,
ImportFrom,
Name,
Expand Down Expand Up @@ -83,12 +84,23 @@ class NotebookRunCall(NodeBase):
def __init__(self, node: Call):
super().__init__(node)

def get_notebook_path(self) -> str | None:
node = DbutilsLinter.get_dbutils_notebook_run_path_arg(cast(Call, self.node))
inferred = next(node.infer(), None)
if isinstance(inferred, Const):
return inferred.value.strip().strip("'").strip('"')
return None
def get_notebook_paths(self) -> list[str | None]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to change the signature?.. dbutils.notebook.run() can have at most two arguments - path and parameters - it can't have multiple paths.

Copy link
Contributor Author

@ericvergnaud ericvergnaud Jun 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need this because astroid is clever enough to return multiple inferred nodes in a scenario such as:

paths = ["p1", "p2"]
for path in paths:
    dbutils.notebook.run(path)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, interesting. Please add it as a code comment, so that the next time reading this code won't catch by surprise

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

node = DbutilsLinter.get_dbutils_notebook_run_path_arg(self.node)
try:
return self._get_notebook_paths(node.infer())
except InferenceError:
logger.debug(f"Can't infer value(s) of {node.as_string()}")
return [None]

@classmethod
def _get_notebook_paths(cls, nodes: Iterable[NodeNG]) -> list[str | None]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _get_notebook_paths(cls, nodes: Iterable[NodeNG]) -> list[str | None]:
def _get_notebook_paths(cls, nodes: Iterable[NodeNG]) -> list[str]:

let's avoid nullability

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

paths: list[str | None] = []
for node in nodes:
if isinstance(node, Const):
nfx marked this conversation as resolved.
Show resolved Hide resolved
paths.append(node.as_string().strip("'").strip('"'))
continue
paths.append(None)
return paths


T = TypeVar("T", bound=Callable)
Expand All @@ -104,19 +116,20 @@ def lint(self, code: str) -> Iterable[Advice]:
@classmethod
def _convert_dbutils_notebook_run_to_advice(cls, node: NodeNG) -> Advisory:
assert isinstance(node, Call)
path = cls.get_dbutils_notebook_run_path_arg(node)
if isinstance(path, Const):
call = NotebookRunCall(cast(Call, node))
paths = call.get_notebook_paths()
if None in paths:
return Advisory(
'dbutils-notebook-run-literal',
"Call to 'dbutils.notebook.run' will be migrated automatically",
'dbutils-notebook-run-dynamic',
"Path for 'dbutils.notebook.run' is too complex and requires adjusting the notebook path(s)",
node.lineno,
node.col_offset,
node.end_lineno or 0,
node.end_col_offset or 0,
)
return Advisory(
'dbutils-notebook-run-dynamic',
"Path for 'dbutils.notebook.run' is not a constant and requires adjusting the notebook path",
'dbutils-notebook-run-literal',
"Call to 'dbutils.notebook.run' will be migrated automatically",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we won't be migrating notebook.run() logic.

Copy link
Contributor Author

@ericvergnaud ericvergnaud Jun 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So what should be the message (the above was existing) ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we drop this advice altogether ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

node.lineno,
node.col_offset,
node.end_lineno or 0,
Expand Down
15 changes: 9 additions & 6 deletions src/databricks/labs/ucx/source_code/notebooks/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
SourceContainer,
)
from databricks.labs.ucx.source_code.notebooks.cells import CellLanguage
from databricks.labs.ucx.source_code.notebooks.sources import Notebook
from databricks.labs.ucx.source_code.notebooks.sources import Notebook, SUPPORTED_EXTENSION_LANGUAGES
from databricks.labs.ucx.source_code.path_lookup import PathLookup

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,17 +55,20 @@ def load_dependency(self, path_lookup: PathLookup, dependency: Dependency) -> So
except NotFound:
logger.warning(f"Could not read notebook from workspace: {absolute_path}")
return None
language = self._detect_language(content)
language = self.detect_language(absolute_path, content)
if not language:
logger.warning(f"Could not detect language for {absolute_path}")
return None
return Notebook.parse(absolute_path, content, language)

@staticmethod
def _detect_language(content: str):
for language in CellLanguage:
if content.startswith(language.file_magic_header):
return language.language
def detect_language(path: Path, content: str):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this method public?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for testing

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a slippery slope to expose methods public just for testing without a significant need. This case doesn't justify this need and could be tested through other public methods.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's nothing dangerous about this method, so not sure why this one is slippery ? or maybe we should allow access to private methods in unit testing, such that we can actually write unit tests rather than slower and complex end-to-end tests ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

slippery slope is that if we allow it for trivial cases, then inexperienced devs would expose inner workings of classes as Public API, resulting in a more fragile system. This codebase was in that state 6 months ago and we're not going back there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noted

language = SUPPORTED_EXTENSION_LANGUAGES.get(path.suffix, None)
if language:
return language
for cell_language in CellLanguage:
if content.startswith(cell_language.file_magic_header):
return cell_language.language
return None

@staticmethod
Expand Down
123 changes: 123 additions & 0 deletions tests/unit/source_code/linters/test_python_ast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import functools
import operator

import pytest
from astroid import Attribute, Call, Const, Expr # type: ignore

from databricks.labs.ucx.source_code.linters.imports import DbutilsLinter
from databricks.labs.ucx.source_code.linters.python_ast import Tree


def test_extract_call_by_name():
tree = Tree.parse("o.m1().m2().m3()")
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.extract_call_by_name(stmt.value, "m2")
assert isinstance(act, Call)
assert isinstance(act.func, Attribute)
assert act.func.attrname == "m2"


def test_extract_call_by_name_none():
tree = Tree.parse("o.m1().m2().m3()")
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.extract_call_by_name(stmt.value, "m5000")
assert act is None


@pytest.mark.parametrize(
"code, arg_index, arg_name, expected",
[
("o.m1()", 1, "second", None),
("o.m1(3)", 1, "second", None),
("o.m1(first=3)", 1, "second", None),
("o.m1(4, 3)", None, None, None),
("o.m1(4, 3)", None, "second", None),
("o.m1(4, 3)", 1, "second", 3),
("o.m1(4, 3)", 1, None, 3),
("o.m1(first=4, second=3)", 1, "second", 3),
("o.m1(second=3, first=4)", 1, "second", 3),
("o.m1(second=3, first=4)", None, "second", 3),
("o.m1(second=3)", 1, "second", 3),
("o.m1(4, 3, 2)", 1, "second", 3),
],
)
def test_linter_gets_arg(code, arg_index, arg_name, expected):
tree = Tree.parse(code)
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.get_arg(stmt.value, arg_index, arg_name)
if expected is None:
assert act is None
else:
assert isinstance(act, Const)
assert act.value == expected


@pytest.mark.parametrize(
"code, expected",
[
("o.m1()", 0),
("o.m1(3)", 1),
("o.m1(first=3)", 1),
("o.m1(3, 3)", 2),
("o.m1(first=3, second=3)", 2),
("o.m1(3, second=3)", 2),
("o.m1(3, *b, **c, second=3)", 4),
],
)
def test_args_count(code, expected):
tree = Tree.parse(code)
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.args_count(stmt.value)
assert act == expected


def test_tree_walks_nodes_once():
nodes = set()
count = 0
tree = Tree.parse("o.m1().m2().m3()")
for node in tree.walk():
nodes.add(node)
count += 1
assert len(nodes) == count


@pytest.mark.parametrize(
nfx marked this conversation as resolved.
Show resolved Hide resolved
"code, expected",
[
(
"""
name = "xyz"
dbutils.notebook.run(name)
""",
["xyz"],
),
(
"""
name = "xyz" + "-" + "abc"
dbutils.notebook.run(name)
""",
["xyz-abc"],
),
(
"""
names = ["abc", "xyz"]
for name in names:
dbutils.notebook.run(name)
nfx marked this conversation as resolved.
Show resolved Hide resolved
""",
["abc", "xyz"],
),
],
)
def test_infers_dbutils_notebook_run_dynamic_value(code, expected):
tree = Tree.parse(code)
calls = DbutilsLinter.list_dbutils_notebook_run_calls(tree)
actual = functools.reduce(operator.iconcat, list(call.get_notebook_paths() for call in calls), [])
assert expected == actual
103 changes: 0 additions & 103 deletions tests/unit/source_code/linters/test_python_imports.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from __future__ import annotations


import pytest
from astroid import Attribute, Call, Const, Expr # type: ignore
from databricks.labs.ucx.source_code.graph import DependencyProblem

from databricks.labs.ucx.source_code.linters.imports import DbutilsLinter, ImportSource, SysPathChange
Expand Down Expand Up @@ -135,103 +132,3 @@ def test_linter_returns_appended_relative_paths_with_os_path_abspath_alias():
tree = Tree.parse(code)
appended = SysPathChange.extract_from_tree(tree)
assert "relative_path" in [p.path for p in appended]


def test_extract_call_by_name():
tree = Tree.parse("o.m1().m2().m3()")
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.extract_call_by_name(stmt.value, "m2")
assert isinstance(act, Call)
assert isinstance(act.func, Attribute)
assert act.func.attrname == "m2"


def test_extract_call_by_name_none():
tree = Tree.parse("o.m1().m2().m3()")
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.extract_call_by_name(stmt.value, "m5000")
assert act is None


@pytest.mark.parametrize(
"code, arg_index, arg_name, expected",
[
("o.m1()", 1, "second", None),
("o.m1(3)", 1, "second", None),
("o.m1(first=3)", 1, "second", None),
("o.m1(4, 3)", None, None, None),
("o.m1(4, 3)", None, "second", None),
("o.m1(4, 3)", 1, "second", 3),
("o.m1(4, 3)", 1, None, 3),
("o.m1(first=4, second=3)", 1, "second", 3),
("o.m1(second=3, first=4)", 1, "second", 3),
("o.m1(second=3, first=4)", None, "second", 3),
("o.m1(second=3)", 1, "second", 3),
("o.m1(4, 3, 2)", 1, "second", 3),
],
)
def test_linter_gets_arg(code, arg_index, arg_name, expected):
tree = Tree.parse(code)
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.get_arg(stmt.value, arg_index, arg_name)
if expected is None:
assert act is None
else:
assert isinstance(act, Const)
assert act.value == expected


@pytest.mark.parametrize(
"code, expected",
[
("o.m1()", 0),
("o.m1(3)", 1),
("o.m1(first=3)", 1),
("o.m1(3, 3)", 2),
("o.m1(first=3, second=3)", 2),
("o.m1(3, second=3)", 2),
("o.m1(3, *b, **c, second=3)", 4),
],
)
def test_args_count(code, expected):
tree = Tree.parse(code)
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.args_count(stmt.value)
assert act == expected


@pytest.mark.parametrize(
"code, expected",
[
(
"""
name = "xyz"
dbutils.notebook.run(name)
""",
"xyz",
)
],
)
def test_infers_string_variable_value(code, expected):
tree = Tree.parse(code)
calls = DbutilsLinter.list_dbutils_notebook_run_calls(tree)
actual = list(call.get_notebook_path() for call in calls)
assert [expected] == actual


def test_tree_walker_walks_nodes_once():
nodes = set()
count = 0
tree = Tree.parse("o.m1().m2().m3()")
for node in tree.walk():
nodes.add(node)
count += 1
assert len(nodes) == count
12 changes: 12 additions & 0 deletions tests/unit/source_code/notebooks/test_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from pathlib import Path

from databricks.labs.ucx.source_code.notebooks.loaders import NotebookLoader
from databricks.sdk.service.workspace import Language


def test_detects_language():
assert NotebookLoader.detect_language(Path("hi.py"), "stuff") == Language.PYTHON
nfx marked this conversation as resolved.
Show resolved Hide resolved
assert NotebookLoader.detect_language(Path("hi.sql"), "stuff") == Language.SQL
assert NotebookLoader.detect_language(Path("hi"), "# Databricks notebook source") == Language.PYTHON
assert NotebookLoader.detect_language(Path("hi"), "-- Databricks notebook source") == Language.SQL
assert not NotebookLoader.detect_language(Path("hi"), "stuff")
Loading
Loading