Skip to content
/ dvc Public
forked from iterative/dvc

Commit

Permalink
serialize: py: try to track segments of the source
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed May 12, 2021
1 parent 0ca08d2 commit 9aa008e
Showing 1 changed file with 32 additions and 12 deletions.
44 changes: 32 additions & 12 deletions dvc/utils/serialize/_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def parse_py(text, path):
with reraise(SyntaxError, PythonFileCorruptedError(path)):
tree = ast.parse(text, filename=path)

result = _ast_tree_to_dict(tree)
lines = text.splitlines()
result = _ast_tree_to_dict(tree, lines)
return result


Expand All @@ -32,8 +33,9 @@ def parse_py_for_update(text, path):
with reraise(SyntaxError, PythonFileCorruptedError(path)):
tree = ast.parse(text, filename=path)

result = _ast_tree_to_dict(tree)
result.update({_PARAMS_KEY: _ast_tree_to_dict(tree, lineno=True)})
lines = text.splitlines()
result = _ast_tree_to_dict(tree, lines)
result.update({_PARAMS_KEY: _ast_tree_to_dict(tree, lines, lineno=True)})
result.update({_PARAMS_TEXT_KEY: text})
return result

Expand All @@ -53,10 +55,13 @@ def _update_lines(lines, old_dct, new_dct):
if isinstance(value, dict):
lines = _update_lines(lines, old_dct[key], value)
elif value != old_dct[key]["value"]:
old_value = old_dct[key]["value"]
lineno = old_dct[key]["lineno"]
lines[lineno] = lines[lineno].replace(
f" = {old_dct[key]['value']}", f" = {value}"
)

segment = old_dct[key].get("segment")
old_segment = " = {}".format(segment or old_value)
new_segment = " = {}".format(value)
lines[lineno] = lines[lineno].replace(old_segment, new_segment)
else:
continue
return lines
Expand Down Expand Up @@ -86,7 +91,7 @@ def modify_py(path, fs=None):
yield d


def _ast_tree_to_dict(tree, only_self_params=False, lineno=False):
def _ast_tree_to_dict(tree, src_lines, only_self_params=False, lineno=False):
"""Parses ast trees to dict.
:param tree: ast.Tree
Expand All @@ -99,18 +104,24 @@ def _ast_tree_to_dict(tree, only_self_params=False, lineno=False):
try:
if isinstance(_body, (ast.Assign, ast.AnnAssign)):
result.update(
_ast_assign_to_dict(_body, only_self_params, lineno)
_ast_assign_to_dict(
_body, src_lines, only_self_params, lineno
)
)
elif isinstance(_body, ast.ClassDef):
result.update(
{_body.name: _ast_tree_to_dict(_body, lineno=lineno)}
{
_body.name: _ast_tree_to_dict(
_body, src_lines, lineno=lineno
)
}
)
elif (
isinstance(_body, ast.FunctionDef) and _body.name == "__init__"
):
result.update(
_ast_tree_to_dict(
_body, only_self_params=True, lineno=lineno
_body, src_lines, only_self_params=True, lineno=lineno
)
)
except ValueError:
Expand All @@ -120,7 +131,9 @@ def _ast_tree_to_dict(tree, only_self_params=False, lineno=False):
return result


def _ast_assign_to_dict(assign, only_self_params=False, lineno=False):
def _ast_assign_to_dict(
assign, src_lines, only_self_params=False, lineno=False
):
result = {}

if isinstance(assign, ast.AnnAssign):
Expand Down Expand Up @@ -152,7 +165,14 @@ def _ast_assign_to_dict(assign, only_self_params=False, lineno=False):
value = _get_ast_value(assign.value)

if lineno and not isinstance(assign.value, ast.Dict):
result[name] = {"lineno": assign.lineno - 1, "value": value}
v = assign.value
offsets = slice(v.col_offset, v.end_col_offset)
lno = assign.lineno - 1
result[name] = {
"lineno": lno,
"value": value,
"segment": src_lines[lno][offsets],
}
else:
result[name] = value

Expand Down

0 comments on commit 9aa008e

Please sign in to comment.