Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: levenshtein optimization #3780

Merged
merged 5 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/functional/syntax/modules/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,4 +1178,4 @@ def test_ownership_decl_errors_not_swallowed(make_input_bundle):
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(UndeclaredDefinition) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "'lib2' has not been declared. "
assert e.value._message == "'lib2' has not been declared."
39 changes: 32 additions & 7 deletions tests/functional/syntax/test_for_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def foo():
""",
StructureException,
"Invalid syntax for loop iterator",
None,
"a[1]",
),
(
Expand All @@ -32,6 +33,7 @@ def bar():
""",
StructureException,
"Bound must be at least 1",
None,
"0",
),
(
Expand All @@ -44,6 +46,7 @@ def foo():
""",
StateAccessViolation,
"Bound must be a literal",
None,
"x",
),
(
Expand All @@ -55,6 +58,7 @@ def foo():
""",
StructureException,
"Please remove the `bound=` kwarg when using range with constants",
None,
"5",
),
(
Expand All @@ -66,6 +70,7 @@ def foo():
""",
StructureException,
"Bound must be at least 1",
None,
"0",
),
(
Expand All @@ -78,6 +83,7 @@ def bar():
""",
ArgumentException,
"Invalid keyword argument 'extra'",
None,
"extra=3",
),
(
Expand All @@ -89,6 +95,7 @@ def bar():
""",
StructureException,
"End must be greater than start",
None,
"0",
),
(
Expand All @@ -101,6 +108,7 @@ def bar():
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
),
(
Expand All @@ -113,6 +121,7 @@ def bar():
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
),
(
Expand All @@ -125,6 +134,7 @@ def repeat(n: uint256) -> uint256:
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"n * 10",
),
(
Expand All @@ -137,6 +147,7 @@ def bar():
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"x + 1",
),
(
Expand All @@ -148,6 +159,7 @@ def bar():
""",
StructureException,
"End must be greater than start",
None,
"1",
),
(
Expand All @@ -160,6 +172,7 @@ def bar():
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
),
(
Expand All @@ -172,6 +185,7 @@ def foo():
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
),
(
Expand All @@ -184,6 +198,7 @@ def repeat(n: uint256) -> uint256:
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"n",
),
(
Expand All @@ -196,6 +211,7 @@ def foo(x: int128):
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
),
(
Expand All @@ -207,6 +223,7 @@ def bar(x: uint256):
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
),
(
Expand All @@ -221,6 +238,7 @@ def foo():
""",
TypeMismatch,
"Given reference has type int128, expected uint256",
None,
"FOO",
),
(
Expand All @@ -234,6 +252,7 @@ def foo():
""",
StructureException,
"Bound must be at least 1",
None,
"FOO",
),
(
Expand All @@ -244,7 +263,8 @@ def foo():
pass
""",
UnknownType,
"No builtin or user-defined type named 'DynArra'. Did you mean 'DynArray'?",
"No builtin or user-defined type named 'DynArra'.",
"Did you mean 'DynArray'?",
"DynArra",
),
(
Expand All @@ -262,7 +282,8 @@ def foo():
pass
""",
UnknownType,
"No builtin or user-defined type named 'uint9'. Did you mean 'uint96', or maybe 'uint8'?",
"No builtin or user-defined type named 'uint9'.",
"Did you mean 'uint96', or maybe 'uint8'?",
"uint9",
),
(
Expand All @@ -278,7 +299,8 @@ def foo():
pass
""",
UnknownType,
"No builtin or user-defined type named 'uint9'. Did you mean 'uint96', or maybe 'uint8'?",
"No builtin or user-defined type named 'uint9'.",
"Did you mean 'uint96', or maybe 'uint8'?",
"uint9",
),
]
Expand All @@ -289,15 +311,18 @@ def foo():
f"{i:02d}: {for_code_regex.search(code).group(1)}" # type: ignore[union-attr]
f" raises {type(err).__name__}"
)
for i, (code, err, msg, src) in enumerate(fail_list)
for i, (code, err, msg, hint, src) in enumerate(fail_list)
]


@pytest.mark.parametrize("bad_code,error_type,message,source_code", fail_list, ids=fail_test_names)
def test_range_fail(bad_code, error_type, message, source_code):
@pytest.mark.parametrize(
"bad_code,error_type,message,hint,source_code", fail_list, ids=fail_test_names
)
def test_range_fail(bad_code, error_type, message, hint, source_code):
with pytest.raises(error_type) as exc_info:
compiler.compile_code(bad_code)
assert message == exc_info.value.message
assert message == exc_info.value._message
assert hint == exc_info.value.hint
assert source_code == exc_info.value.args[1].get_original_node().node_source_code


Expand Down
13 changes: 11 additions & 2 deletions vyper/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,20 @@ def with_annotation(self, *annotations):
exc.annotations = annotations
return exc

@property
def hint(self):
# some hints are expensive to compute, so we wait until the last
# minute when the formatted message is actually requested to compute
# them.
if callable(self._hint):
return self._hint()
return self._hint

@property
def message(self):
msg = self._message
if self._hint:
msg += f"\n\n (hint: {self._hint})"
if self.hint:
msg += f"\n\n (hint: {self.hint})"
return msg

def __str__(self):
Expand Down
10 changes: 8 additions & 2 deletions vyper/semantics/analysis/levenshtein_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Callable


def levenshtein_norm(source: str, target: str) -> float:
Expand Down Expand Up @@ -73,7 +73,13 @@ def levenshtein(source: str, target: str) -> int:
return matrix[len(source)][len(target)]


def get_levenshtein_error_suggestions(key: str, namespace: Dict[str, Any], threshold: float) -> str:
def get_levenshtein_error_suggestions(*args, **kwargs) -> Callable:
return lambda: _get_levenshtein_error_suggestions(*args, **kwargs)


def _get_levenshtein_error_suggestions(
key: str, namespace: dict[str, Any], threshold: float
) -> str:
"""
Generate an error message snippet for the suggested closest values in the provided namespace
with the shortest normalized Levenshtein distance from the given key if that distance
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ def _raise_invalid_reference(name, node):
if name in self.namespace:
_raise_invalid_reference(name, node)

suggestions_str = get_levenshtein_error_suggestions(name, t.members, 0.4)
hint = get_levenshtein_error_suggestions(name, t.members, 0.4)
raise UndeclaredDefinition(
f"Storage variable '{name}' has not been declared. {suggestions_str}", node
f"Storage variable '{name}' has not been declared.", node, hint=hint
) from None

def types_from_BinOp(self, node):
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def __setitem__(self, attr, obj):

def __getitem__(self, key):
if key not in self:
suggestions_str = get_levenshtein_error_suggestions(key, self, 0.2)
raise UndeclaredDefinition(f"'{key}' has not been declared. {suggestions_str}")
hint = get_levenshtein_error_suggestions(key, self, 0.2)
raise UndeclaredDefinition(f"'{key}' has not been declared.", hint=hint)
return super().__getitem__(key)

def __enter__(self):
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ def get_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType":
if not self.members:
raise StructureException(f"{self} instance does not have members", node)

suggestions_str = get_levenshtein_error_suggestions(key, self.members, 0.3)
raise UnknownAttribute(f"{self} has no member '{key}'. {suggestions_str}", node)
hint = get_levenshtein_error_suggestions(key, self.members, 0.3)
raise UnknownAttribute(f"{self} has no member '{key}'.", node, hint=hint)

def __repr__(self):
return self._id
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/types/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,9 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT":
keys = list(self.member_types.keys())
for i, (key, value) in enumerate(zip(node.args[0].keys, node.args[0].values)):
if key is None or key.get("id") not in members:
suggestions_str = get_levenshtein_error_suggestions(key.get("id"), members, 1.0)
hint = get_levenshtein_error_suggestions(key.get("id"), members, 1.0)
raise UnknownAttribute(
f"Unknown or duplicate struct member. {suggestions_str}", key or value
"Unknown or duplicate struct member.", key or value, hint=hint
)
expected_key = keys[i]
if key.id != expected_key:
Expand Down
5 changes: 2 additions & 3 deletions vyper/semantics/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,9 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType:
raise InvalidType(err_msg, node)

if node.id not in namespace: # type: ignore
suggestions_str = get_levenshtein_error_suggestions(node.node_source_code, namespace, 0.3)
hint = get_levenshtein_error_suggestions(node.node_source_code, namespace, 0.3)
raise UnknownType(
f"No builtin or user-defined type named '{node.node_source_code}'. {suggestions_str}",
node,
f"No builtin or user-defined type named '{node.node_source_code}'.", node, hint=hint
) from None

typ_ = namespace[node.id]
Expand Down
Loading