Skip to content

Commit

Permalink
Added more value inference for dbutils.notebook.run(...) (#1860)
Browse files Browse the repository at this point in the history
  • Loading branch information
ericvergnaud authored Jun 8, 2024
1 parent b19c848 commit 879a5b4
Show file tree
Hide file tree
Showing 10 changed files with 291 additions and 193 deletions.
18 changes: 13 additions & 5 deletions src/databricks/labs/ucx/source_code/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,19 +191,27 @@ def build_graph_from_python_source(self, python_code: str) -> list[DependencyPro
def _process_node(self, base_node: NodeBase):
if isinstance(base_node, SysPathChange):
self._mutate_path_lookup(base_node)
return
if isinstance(base_node, NotebookRunCall):
strpath = base_node.get_notebook_path()
if strpath is None:
yield DependencyProblem('dependency-not-constant', "Can't check dependency not provided as a constant")
else:
yield from self.register_notebook(Path(strpath))
yield from self._register_notebook(base_node)
return
if isinstance(base_node, ImportSource):
prefix = ""
if isinstance(base_node.node, ImportFrom) and base_node.node.level is not None:
prefix = "." * base_node.node.level
name = base_node.name or ""
yield from self.register_import(prefix + name)

def _register_notebook(self, base_node: NotebookRunCall):
has_unresolved, paths = base_node.get_notebook_paths()
if has_unresolved:
yield DependencyProblem(
'dependency-cannot-compute',
f"Can't check dependency from {base_node.node.as_string()} because the expression cannot be computed",
)
for path in paths:
yield from self.register_notebook(Path(path))

def _mutate_path_lookup(self, change: SysPathChange):
path = Path(change.path)
if not path.is_absolute():
Expand Down
58 changes: 38 additions & 20 deletions src/databricks/labs/ucx/source_code/linters/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Attribute,
Call,
Const,
InferenceError,
Import,
ImportFrom,
Name,
Expand Down Expand Up @@ -81,36 +82,53 @@ class NotebookRunCall(NodeBase):
def __init__(self, node: Call):
super().__init__(node)

def get_notebook_path(self) -> str | None:
node = DbutilsLinter.get_dbutils_notebook_run_path_arg(cast(Call, self.node))
inferred = next(node.infer(), None)
if isinstance(inferred, Const):
return inferred.value.strip().strip("'").strip('"')
return None
def get_notebook_paths(self) -> tuple[bool, list[str]]:
"""we return multiple paths because astroid can infer them in scenarios such as:
paths = ["p1", "p2"]
for path in paths:
dbutils.notebook.run(path)
"""
node = DbutilsLinter.get_dbutils_notebook_run_path_arg(self.node)
try:
return self._get_notebook_paths(node.infer())
except InferenceError:
logger.debug(f"Can't infer value(s) of {node.as_string()}")
return True, []

@classmethod
def _get_notebook_paths(cls, nodes: Iterable[NodeNG]) -> tuple[bool, list[str]]:
has_unresolved = False
paths: list[str] = []
for node in nodes:
if isinstance(node, Const):
paths.append(node.as_string().strip("'").strip('"'))
continue
logger.debug(f"Can't compute {type(node).__name__}")
has_unresolved = True
return has_unresolved, paths


class DbutilsLinter(Linter):

def lint(self, code: str) -> Iterable[Advice]:
tree = Tree.parse(code)
nodes = self.list_dbutils_notebook_run_calls(tree)
return [self._convert_dbutils_notebook_run_to_advice(node.node) for node in nodes]
for node in nodes:
yield from self._raise_advice_if_unresolved(node.node)

@classmethod
def _convert_dbutils_notebook_run_to_advice(cls, node: NodeNG) -> Advisory:
def _raise_advice_if_unresolved(cls, node: NodeNG) -> Iterable[Advice]:
assert isinstance(node, Call)
path = cls.get_dbutils_notebook_run_path_arg(node)
if isinstance(path, Const):
return Advisory.from_node(
'dbutils-notebook-run-literal',
"Call to 'dbutils.notebook.run' will be migrated automatically",
node=node,
)
return Advisory.from_node(
'dbutils-notebook-run-dynamic',
"Path for 'dbutils.notebook.run' is not a constant and requires adjusting the notebook path",
node=node,
)
call = NotebookRunCall(cast(Call, node))
has_unresolved, _ = call.get_notebook_paths()
if has_unresolved:
yield from [
Advisory.from_node(
'dbutils-notebook-run-dynamic',
"Path for 'dbutils.notebook.run' cannot be computed and requires adjusting the notebook path(s)",
node=node,
)
]

@staticmethod
def get_dbutils_notebook_run_path_arg(node: Call):
Expand Down
15 changes: 9 additions & 6 deletions src/databricks/labs/ucx/source_code/notebooks/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
SourceContainer,
)
from databricks.labs.ucx.source_code.notebooks.cells import CellLanguage
from databricks.labs.ucx.source_code.notebooks.sources import Notebook
from databricks.labs.ucx.source_code.notebooks.sources import Notebook, SUPPORTED_EXTENSION_LANGUAGES
from databricks.labs.ucx.source_code.path_lookup import PathLookup

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,17 +55,20 @@ def load_dependency(self, path_lookup: PathLookup, dependency: Dependency) -> So
except NotFound:
logger.warning(f"Could not read notebook from workspace: {absolute_path}")
return None
language = self._detect_language(content)
language = self._detect_language(absolute_path, content)
if not language:
logger.warning(f"Could not detect language for {absolute_path}")
return None
return Notebook.parse(absolute_path, content, language)

@staticmethod
def _detect_language(content: str):
for language in CellLanguage:
if content.startswith(language.file_magic_header):
return language.language
def _detect_language(path: Path, content: str):
language = SUPPORTED_EXTENSION_LANGUAGES.get(path.suffix, None)
if language:
return language
for cell_language in CellLanguage:
if content.startswith(cell_language.file_magic_header):
return cell_language.language
return None

@staticmethod
Expand Down
85 changes: 85 additions & 0 deletions tests/unit/source_code/linters/test_python_ast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import pytest
from astroid import Attribute, Call, Const, Expr # type: ignore

from databricks.labs.ucx.source_code.linters.python_ast import Tree


def test_extract_call_by_name():
tree = Tree.parse("o.m1().m2().m3()")
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.extract_call_by_name(stmt.value, "m2")
assert isinstance(act, Call)
assert isinstance(act.func, Attribute)
assert act.func.attrname == "m2"


def test_extract_call_by_name_none():
tree = Tree.parse("o.m1().m2().m3()")
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.extract_call_by_name(stmt.value, "m5000")
assert act is None


@pytest.mark.parametrize(
"code, arg_index, arg_name, expected",
[
("o.m1()", 1, "second", None),
("o.m1(3)", 1, "second", None),
("o.m1(first=3)", 1, "second", None),
("o.m1(4, 3)", None, None, None),
("o.m1(4, 3)", None, "second", None),
("o.m1(4, 3)", 1, "second", 3),
("o.m1(4, 3)", 1, None, 3),
("o.m1(first=4, second=3)", 1, "second", 3),
("o.m1(second=3, first=4)", 1, "second", 3),
("o.m1(second=3, first=4)", None, "second", 3),
("o.m1(second=3)", 1, "second", 3),
("o.m1(4, 3, 2)", 1, "second", 3),
],
)
def test_linter_gets_arg(code, arg_index, arg_name, expected):
tree = Tree.parse(code)
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.get_arg(stmt.value, arg_index, arg_name)
if expected is None:
assert act is None
else:
assert isinstance(act, Const)
assert act.value == expected


@pytest.mark.parametrize(
"code, expected",
[
("o.m1()", 0),
("o.m1(3)", 1),
("o.m1(first=3)", 1),
("o.m1(3, 3)", 2),
("o.m1(first=3, second=3)", 2),
("o.m1(3, second=3)", 2),
("o.m1(3, *b, **c, second=3)", 4),
],
)
def test_args_count(code, expected):
tree = Tree.parse(code)
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.args_count(stmt.value)
assert act == expected


def test_tree_walks_nodes_once():
nodes = set()
count = 0
tree = Tree.parse("o.m1().m2().m3()")
for node in tree.walk():
nodes.add(node)
count += 1
assert len(nodes) == count
119 changes: 32 additions & 87 deletions tests/unit/source_code/linters/test_python_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


import pytest
from astroid import Attribute, Call, Const, Expr # type: ignore

from databricks.labs.ucx.source_code.graph import DependencyProblem

from databricks.labs.ucx.source_code.linters.imports import DbutilsLinter, ImportSource, SysPathChange
Expand Down Expand Up @@ -137,77 +137,6 @@ def test_linter_returns_appended_relative_paths_with_os_path_abspath_alias():
assert "relative_path" in [p.path for p in appended]


def test_extract_call_by_name():
tree = Tree.parse("o.m1().m2().m3()")
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.extract_call_by_name(stmt.value, "m2")
assert isinstance(act, Call)
assert isinstance(act.func, Attribute)
assert act.func.attrname == "m2"


def test_extract_call_by_name_none():
tree = Tree.parse("o.m1().m2().m3()")
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.extract_call_by_name(stmt.value, "m5000")
assert act is None


@pytest.mark.parametrize(
"code, arg_index, arg_name, expected",
[
("o.m1()", 1, "second", None),
("o.m1(3)", 1, "second", None),
("o.m1(first=3)", 1, "second", None),
("o.m1(4, 3)", None, None, None),
("o.m1(4, 3)", None, "second", None),
("o.m1(4, 3)", 1, "second", 3),
("o.m1(4, 3)", 1, None, 3),
("o.m1(first=4, second=3)", 1, "second", 3),
("o.m1(second=3, first=4)", 1, "second", 3),
("o.m1(second=3, first=4)", None, "second", 3),
("o.m1(second=3)", 1, "second", 3),
("o.m1(4, 3, 2)", 1, "second", 3),
],
)
def test_linter_gets_arg(code, arg_index, arg_name, expected):
tree = Tree.parse(code)
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.get_arg(stmt.value, arg_index, arg_name)
if expected is None:
assert act is None
else:
assert isinstance(act, Const)
assert act.value == expected


@pytest.mark.parametrize(
"code, expected",
[
("o.m1()", 0),
("o.m1(3)", 1),
("o.m1(first=3)", 1),
("o.m1(3, 3)", 2),
("o.m1(first=3, second=3)", 2),
("o.m1(3, second=3)", 2),
("o.m1(3, *b, **c, second=3)", 4),
],
)
def test_args_count(code, expected):
tree = Tree.parse(code)
stmt = tree.first_statement()
assert isinstance(stmt, Expr)
assert isinstance(stmt.value, Call)
act = Tree.args_count(stmt.value)
assert act == expected


@pytest.mark.parametrize(
"code, expected",
[
Expand All @@ -216,22 +145,38 @@ def test_args_count(code, expected):
name = "xyz"
dbutils.notebook.run(name)
""",
"xyz",
)
["xyz"],
),
(
"""
name = "xyz" + "-" + "abc"
dbutils.notebook.run(name)
""",
["xyz-abc"],
),
(
"""
names = ["abc", "xyz"]
for name in names:
dbutils.notebook.run(name)
""",
["abc", "xyz"],
),
(
"""
def foo(): return "bar"
name = foo()
dbutils.notebook.run(name)
""",
["bar"],
),
],
)
def test_infers_string_variable_value(code, expected):
def test_infers_dbutils_notebook_run_dynamic_value(code, expected):
tree = Tree.parse(code)
calls = DbutilsLinter.list_dbutils_notebook_run_calls(tree)
actual = list(call.get_notebook_path() for call in calls)
assert [expected] == actual


def test_tree_walker_walks_nodes_once():
nodes = set()
count = 0
tree = Tree.parse("o.m1().m2().m3()")
for node in tree.walk():
nodes.add(node)
count += 1
assert len(nodes) == count
all_paths: list[str] = []
for call in calls:
_, paths = call.get_notebook_paths()
all_paths.extend(paths)
assert all_paths == expected
19 changes: 19 additions & 0 deletions tests/unit/source_code/notebooks/test_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pathlib import Path

from databricks.labs.ucx.source_code.notebooks.loaders import NotebookLoader
from databricks.sdk.service.workspace import Language


def test_detects_language():

class NotebookLoaderForTesting(NotebookLoader):

@classmethod
def detect_language(cls, path: Path, content: str):
return cls._detect_language(path, content)

assert NotebookLoaderForTesting.detect_language(Path("hi.py"), "stuff") == Language.PYTHON
assert NotebookLoaderForTesting.detect_language(Path("hi.sql"), "stuff") == Language.SQL
assert NotebookLoaderForTesting.detect_language(Path("hi"), "# Databricks notebook source") == Language.PYTHON
assert NotebookLoaderForTesting.detect_language(Path("hi"), "-- Databricks notebook source") == Language.SQL
assert not NotebookLoaderForTesting.detect_language(Path("hi"), "stuff")
Loading

0 comments on commit 879a5b4

Please sign in to comment.