Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added more value inference for dbutils.notebook.run(...) #1860

Merged
merged 19 commits into from
Jun 8, 2024
Merged
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions src/databricks/labs/ucx/source_code/graph.py
Original file line number Diff line number Diff line change
@@ -191,19 +191,27 @@ def build_graph_from_python_source(self, python_code: str) -> list[DependencyPro
# Dispatch a single collected node by its type: sys.path changes, notebook run calls, or imports.
# This is a generator: it yields whatever the register_* helpers yield.
# NOTE(review): this listing merges pre-change lines (get_notebook_path / 'dependency-not-constant')
# with the post-change delegation to _register_notebook — confirm against the repository which lines remain.
def _process_node(self, base_node: NodeBase):
if isinstance(base_node, SysPathChange):
# A sys.path mutation only updates the path lookup; nothing is yielded for it.
self._mutate_path_lookup(base_node)
return
if isinstance(base_node, NotebookRunCall):
strpath = base_node.get_notebook_path()
if strpath is None:
yield DependencyProblem('dependency-not-constant', "Can't check dependency not provided as a constant")
else:
yield from self.register_notebook(Path(strpath))
yield from self._register_notebook(base_node)
return
if isinstance(base_node, ImportSource):
prefix = ""
# For `from ... import` nodes with a level, rebuild the leading dots of the relative import.
if isinstance(base_node.node, ImportFrom) and base_node.node.level is not None:
prefix = "." * base_node.node.level
name = base_node.name or ""
yield from self.register_import(prefix + name)

def _register_notebook(self, base_node: NotebookRunCall):
    """Register every notebook path that could be computed for a notebook run call.

    Yields one 'dependency-cannot-compute' problem first when at least one
    candidate value of the path expression could not be computed, then
    whatever ``register_notebook`` yields for each computed path.
    """
    unresolved, notebook_paths = base_node.get_notebook_paths()
    if unresolved:
        problem = DependencyProblem(
            'dependency-cannot-compute',
            f"Can't check dependency from {base_node.node.as_string()} because the expression cannot be computed",
        )
        yield problem
    for notebook_path in notebook_paths:
        yield from self.register_notebook(Path(notebook_path))

def _mutate_path_lookup(self, change: SysPathChange):
path = Path(change.path)
if not path.is_absolute():
58 changes: 38 additions & 20 deletions src/databricks/labs/ucx/source_code/linters/imports.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
Attribute,
Call,
Const,
InferenceError,
Import,
ImportFrom,
Name,
@@ -81,36 +82,53 @@ class NotebookRunCall(NodeBase):
def __init__(self, node: Call):
    """Wrap the given call node; storing it is delegated to the base-class initializer."""
    super().__init__(node)

def get_notebook_path(self) -> str | None:
    """Return the notebook path argument when it infers to a string constant.

    Only the first inferred value of the path argument is considered.
    Surrounding whitespace and quote characters are stripped from it.

    Returns:
        The cleaned path, or None when the first inferred value is missing
        or is not a string constant.
    """
    node = DbutilsLinter.get_dbutils_notebook_run_path_arg(cast(Call, self.node))
    inferred = next(node.infer(), None)
    # A Const can wrap any literal (int, None, bytes, ...). Only str supports the
    # strip() calls below, so anything else is reported as "not a constant path"
    # instead of raising AttributeError/TypeError.
    if isinstance(inferred, Const) and isinstance(inferred.value, str):
        return inferred.value.strip().strip("'").strip('"')
    return None
def get_notebook_paths(self) -> tuple[bool, list[str]]:
    """Infer all candidate notebook paths for this call.

    Several paths may be returned because astroid can infer more than one
    value, for example:
        paths = ["p1", "p2"]
        for path in paths:
            dbutils.notebook.run(path)

    Returns:
        A ``(has_unresolved, paths)`` pair. When inference fails with
        InferenceError the result is ``(True, [])``.
    """
    path_arg = DbutilsLinter.get_dbutils_notebook_run_path_arg(self.node)
    try:
        outcome = self._get_notebook_paths(path_arg.infer())
    except InferenceError:
        logger.debug(f"Can't infer value(s) of {path_arg.as_string()}")
        outcome = (True, [])
    return outcome

@classmethod
def _get_notebook_paths(cls, nodes: Iterable[NodeNG]) -> tuple[bool, list[str]]:
    """Split inferred nodes into computed notebook paths and an "unresolved" flag.

    Args:
        nodes: the inferred values of the notebook path argument.

    Returns:
        A ``(has_unresolved, paths)`` pair: ``paths`` holds one entry per
        constant node, and ``has_unresolved`` is True when at least one node
        was not a constant.
    """
    has_unresolved = False
    paths: list[str] = []
    for node in nodes:
        if isinstance(node, Const):
            value = node.value
            # as_string() renders the literal as source code (its repr), which
            # escapes backslashes and embedded quotes. For string constants
            # take the real value; keep the rendered form for other literals.
            text = value if isinstance(value, str) else node.as_string()
            # Quote stripping is kept so results match the previous behaviour.
            paths.append(text.strip("'").strip('"'))
            continue
        logger.debug(f"Can't compute {type(node).__name__}")
        has_unresolved = True
    return has_unresolved, paths


class DbutilsLinter(Linter):

# Parse the code, collect the dbutils.notebook.run calls and produce advice for them.
# NOTE(review): the list-returning line is the pre-change form and the loop with `yield from`
# is the post-change form; both appear here because the diff markers were lost — confirm which survives.
def lint(self, code: str) -> Iterable[Advice]:
tree = Tree.parse(code)
nodes = self.list_dbutils_notebook_run_calls(tree)
return [self._convert_dbutils_notebook_run_to_advice(node.node) for node in nodes]
for node in nodes:
yield from self._raise_advice_if_unresolved(node.node)

# Build advice for one dbutils.notebook.run call node.
# NOTE(review): two definitions are interleaved below. The first signature and the two `return`
# statements look like the pre-change version (always returns one Advisory); the second signature
# and the `yield from` look like the post-change version — confirm against the repository.
@classmethod
def _convert_dbutils_notebook_run_to_advice(cls, node: NodeNG) -> Advisory:
def _raise_advice_if_unresolved(cls, node: NodeNG) -> Iterable[Advice]:
assert isinstance(node, Call)
path = cls.get_dbutils_notebook_run_path_arg(node)
if isinstance(path, Const):
return Advisory.from_node(
'dbutils-notebook-run-literal',
"Call to 'dbutils.notebook.run' will be migrated automatically",
node=node,
)
return Advisory.from_node(
'dbutils-notebook-run-dynamic',
"Path for 'dbutils.notebook.run' is not a constant and requires adjusting the notebook path",
node=node,
)
call = NotebookRunCall(cast(Call, node))
# Only the unresolved flag is used here; the computed paths themselves are discarded.
has_unresolved, _ = call.get_notebook_paths()
if has_unresolved:
yield from [
Advisory.from_node(
'dbutils-notebook-run-dynamic',
"Path for 'dbutils.notebook.run' cannot be computed and requires adjusting the notebook path(s)",
node=node,
)
]

@staticmethod
def get_dbutils_notebook_run_path_arg(node: Call):
15 changes: 9 additions & 6 deletions src/databricks/labs/ucx/source_code/notebooks/loaders.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
SourceContainer,
)
from databricks.labs.ucx.source_code.notebooks.cells import CellLanguage
from databricks.labs.ucx.source_code.notebooks.sources import Notebook
from databricks.labs.ucx.source_code.notebooks.sources import Notebook, SUPPORTED_EXTENSION_LANGUAGES
from databricks.labs.ucx.source_code.path_lookup import PathLookup

logger = logging.getLogger(__name__)
@@ -55,17 +55,20 @@ def load_dependency(self, path_lookup: PathLookup, dependency: Dependency) -> So
except NotFound:
logger.warning(f"Could not read notebook from workspace: {absolute_path}")
return None
language = self._detect_language(content)
language = self._detect_language(absolute_path, content)
if not language:
logger.warning(f"Could not detect language for {absolute_path}")
return None
return Notebook.parse(absolute_path, content, language)

# Work out the language of a notebook.
# NOTE(review): the content-only signature and its loop look like the pre-change version; the
# (path, content) signature and everything after it look like the post-change version — confirm.
@staticmethod
def _detect_language(content: str):
for language in CellLanguage:
if content.startswith(language.file_magic_header):
return language.language
def _detect_language(path: Path, content: str):
# A known file suffix takes precedence over the content header.
language = SUPPORTED_EXTENSION_LANGUAGES.get(path.suffix, None)
if language:
return language
# Otherwise match the start of the content against each cell language's file magic header.
for cell_language in CellLanguage:
if content.startswith(cell_language.file_magic_header):
return cell_language.language
# Neither the suffix nor the header identified a language.
return None

@staticmethod
85 changes: 85 additions & 0 deletions tests/unit/source_code/linters/test_python_ast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import pytest
from astroid import Attribute, Call, Const, Expr # type: ignore

from databricks.labs.ucx.source_code.linters.python_ast import Tree


def test_extract_call_by_name():
    """extract_call_by_name returns the ``m2`` call from a chained call expression."""
    statement = Tree.parse("o.m1().m2().m3()").first_statement()
    assert isinstance(statement, Expr)
    assert isinstance(statement.value, Call)
    found = Tree.extract_call_by_name(statement.value, "m2")
    assert isinstance(found, Call)
    callee = found.func
    assert isinstance(callee, Attribute)
    assert callee.attrname == "m2"


def test_extract_call_by_name_none():
    """extract_call_by_name returns None when no call in the chain has the requested name."""
    statement = Tree.parse("o.m1().m2().m3()").first_statement()
    assert isinstance(statement, Expr)
    assert isinstance(statement.value, Call)
    found = Tree.extract_call_by_name(statement.value, "m5000")
    assert found is None


@pytest.mark.parametrize(
    "code, arg_index, arg_name, expected",
    [
        ("o.m1()", 1, "second", None),
        ("o.m1(3)", 1, "second", None),
        ("o.m1(first=3)", 1, "second", None),
        ("o.m1(4, 3)", None, None, None),
        ("o.m1(4, 3)", None, "second", None),
        ("o.m1(4, 3)", 1, "second", 3),
        ("o.m1(4, 3)", 1, None, 3),
        ("o.m1(first=4, second=3)", 1, "second", 3),
        ("o.m1(second=3, first=4)", 1, "second", 3),
        ("o.m1(second=3, first=4)", None, "second", 3),
        ("o.m1(second=3)", 1, "second", 3),
        ("o.m1(4, 3, 2)", 1, "second", 3),
    ],
)
def test_linter_gets_arg(code, arg_index, arg_name, expected):
    """get_arg resolves an argument by position and/or keyword, or returns None."""
    statement = Tree.parse(code).first_statement()
    assert isinstance(statement, Expr)
    assert isinstance(statement.value, Call)
    argument = Tree.get_arg(statement.value, arg_index, arg_name)
    if expected is None:
        assert argument is None
        return
    assert isinstance(argument, Const)
    assert argument.value == expected


@pytest.mark.parametrize(
    "code, expected",
    [
        ("o.m1()", 0),
        ("o.m1(3)", 1),
        ("o.m1(first=3)", 1),
        ("o.m1(3, 3)", 2),
        ("o.m1(first=3, second=3)", 2),
        ("o.m1(3, second=3)", 2),
        ("o.m1(3, *b, **c, second=3)", 4),
    ],
)
def test_args_count(code, expected):
    """args_count counts positional, keyword, starred and double-starred arguments."""
    statement = Tree.parse(code).first_statement()
    assert isinstance(statement, Expr)
    assert isinstance(statement.value, Call)
    counted = Tree.args_count(statement.value)
    assert counted == expected


def test_tree_walks_nodes_once():
    """walk() must not visit the same node twice."""
    tree = Tree.parse("o.m1().m2().m3()")
    visited = list(tree.walk())
    # Duplicates would collapse in the set and make it smaller than the visit list.
    assert len(set(visited)) == len(visited)
119 changes: 32 additions & 87 deletions tests/unit/source_code/linters/test_python_imports.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@


import pytest
from astroid import Attribute, Call, Const, Expr # type: ignore

from databricks.labs.ucx.source_code.graph import DependencyProblem

from databricks.labs.ucx.source_code.linters.imports import DbutilsLinter, ImportSource, SysPathChange
@@ -137,77 +137,6 @@ def test_linter_returns_appended_relative_paths_with_os_path_abspath_alias():
assert "relative_path" in [p.path for p in appended]


def test_extract_call_by_name():
    """extract_call_by_name returns the ``m2`` call from a chained call expression."""
    statement = Tree.parse("o.m1().m2().m3()").first_statement()
    assert isinstance(statement, Expr)
    assert isinstance(statement.value, Call)
    found = Tree.extract_call_by_name(statement.value, "m2")
    assert isinstance(found, Call)
    callee = found.func
    assert isinstance(callee, Attribute)
    assert callee.attrname == "m2"


def test_extract_call_by_name_none():
    """extract_call_by_name returns None when no call in the chain has the requested name."""
    statement = Tree.parse("o.m1().m2().m3()").first_statement()
    assert isinstance(statement, Expr)
    assert isinstance(statement.value, Call)
    found = Tree.extract_call_by_name(statement.value, "m5000")
    assert found is None


@pytest.mark.parametrize(
    "code, arg_index, arg_name, expected",
    [
        ("o.m1()", 1, "second", None),
        ("o.m1(3)", 1, "second", None),
        ("o.m1(first=3)", 1, "second", None),
        ("o.m1(4, 3)", None, None, None),
        ("o.m1(4, 3)", None, "second", None),
        ("o.m1(4, 3)", 1, "second", 3),
        ("o.m1(4, 3)", 1, None, 3),
        ("o.m1(first=4, second=3)", 1, "second", 3),
        ("o.m1(second=3, first=4)", 1, "second", 3),
        ("o.m1(second=3, first=4)", None, "second", 3),
        ("o.m1(second=3)", 1, "second", 3),
        ("o.m1(4, 3, 2)", 1, "second", 3),
    ],
)
def test_linter_gets_arg(code, arg_index, arg_name, expected):
    """get_arg resolves an argument by position and/or keyword, or returns None."""
    statement = Tree.parse(code).first_statement()
    assert isinstance(statement, Expr)
    assert isinstance(statement.value, Call)
    argument = Tree.get_arg(statement.value, arg_index, arg_name)
    if expected is None:
        assert argument is None
        return
    assert isinstance(argument, Const)
    assert argument.value == expected


@pytest.mark.parametrize(
    "code, expected",
    [
        ("o.m1()", 0),
        ("o.m1(3)", 1),
        ("o.m1(first=3)", 1),
        ("o.m1(3, 3)", 2),
        ("o.m1(first=3, second=3)", 2),
        ("o.m1(3, second=3)", 2),
        ("o.m1(3, *b, **c, second=3)", 4),
    ],
)
def test_args_count(code, expected):
    """args_count counts positional, keyword, starred and double-starred arguments."""
    statement = Tree.parse(code).first_statement()
    assert isinstance(statement, Expr)
    assert isinstance(statement.value, Call)
    counted = Tree.args_count(statement.value)
    assert counted == expected


@pytest.mark.parametrize(
"code, expected",
[
@@ -216,22 +145,38 @@ def test_args_count(code, expected):
name = "xyz"
dbutils.notebook.run(name)
""",
"xyz",
)
["xyz"],
),
(
"""
name = "xyz" + "-" + "abc"
dbutils.notebook.run(name)
""",
["xyz-abc"],
),
(
"""
names = ["abc", "xyz"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we already do

def foo(): return "bar"
name = foo()
dbutils.notebook.run(name)

or does it require building a small type-aware interpreter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can! added corresponding test in test_infers_dbutils_notebook_run_dynamic_value

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

for name in names:
dbutils.notebook.run(name)
""",
["abc", "xyz"],
),
(
"""
def foo(): return "bar"
name = foo()
dbutils.notebook.run(name)
""",
["bar"],
),
],
)
# Checks the notebook paths inferred for each dbutils.notebook.run call in the parametrized snippet.
# NOTE(review): three things are interleaved here because the diff markers were lost: the pre-change
# test (get_notebook_path, single expected value), the removed test_tree_walker_walks_nodes_once, and
# the post-change test body (get_notebook_paths, list of expected values) — confirm against the repository.
def test_infers_string_variable_value(code, expected):
def test_infers_dbutils_notebook_run_dynamic_value(code, expected):
tree = Tree.parse(code)
calls = DbutilsLinter.list_dbutils_notebook_run_calls(tree)
actual = list(call.get_notebook_path() for call in calls)
assert [expected] == actual


def test_tree_walker_walks_nodes_once():
nodes = set()
count = 0
tree = Tree.parse("o.m1().m2().m3()")
for node in tree.walk():
nodes.add(node)
count += 1
assert len(nodes) == count
# Collect the paths of every call in order; the unresolved flag is ignored by this test.
all_paths: list[str] = []
for call in calls:
_, paths = call.get_notebook_paths()
all_paths.extend(paths)
assert all_paths == expected
19 changes: 19 additions & 0 deletions tests/unit/source_code/notebooks/test_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pathlib import Path

from databricks.labs.ucx.source_code.notebooks.loaders import NotebookLoader
from databricks.sdk.service.workspace import Language


def test_detects_language():
    """Language is detected from the file suffix first, then from the notebook header."""

    class NotebookLoaderForTesting(NotebookLoader):
        """Exposes the protected detection helper to the test."""

        @classmethod
        def detect_language(cls, path: Path, content: str):
            return cls._detect_language(path, content)

    detect = NotebookLoaderForTesting.detect_language
    detected_cases = [
        (Path("hi.py"), "stuff", Language.PYTHON),
        (Path("hi.sql"), "stuff", Language.SQL),
        (Path("hi"), "# Databricks notebook source", Language.PYTHON),
        (Path("hi"), "-- Databricks notebook source", Language.SQL),
    ]
    for path, content, language in detected_cases:
        assert detect(path, content) == language
    # No known suffix and no recognised header: nothing is detected.
    assert not detect(Path("hi"), "stuff")
Loading