Skip to content

Commit

Permalink
fix[lang]: fix array index checks when the subscript is folded (vyper…
Browse files Browse the repository at this point in the history
…lang#3924)

this commit fixes a regression introduced in 56c4c9d. prior to
56c4c9d, the folded subscript would be checked for OOB access, but
after 56c4c9d, expressions like `foo[0 - 1]` can slip past the
typechecker (getting demoted to a runtime check). also, a common pattern
is refactored.

common pattern:
```python
if node.has_folded_value:
    node = node.get_folded_value()
```
=>
```
node = node.reduced()
```
  • Loading branch information
charles-cooper authored and electriclilies committed Apr 27, 2024
1 parent 2965d25 commit 7db4f87
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 15 deletions.
22 changes: 22 additions & 0 deletions tests/unit/ast/nodes/test_fold_subscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from hypothesis import strategies as st

from tests.utils import parse_and_fold
from vyper.compiler import compile_code
from vyper.exceptions import ArrayIndexException


@pytest.mark.fuzzing
Expand All @@ -24,3 +26,23 @@ def foo(array: int128[10], idx: uint256) -> int128:
new_node = old_node.get_folded_value()

assert contract.foo(array, idx) == new_node.value


def test_negative_index():
source = """
@external
def foo(array: int128[10]) -> int128:
return array[0 - 1]
"""
with pytest.raises(ArrayIndexException):
compile_code(source)


def test_oob_index():
source = """
@external
def foo(array: int128[10]) -> int128:
return array[9 + 1]
"""
with pytest.raises(ArrayIndexException):
compile_code(source)
5 changes: 5 additions & 0 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,11 @@ def get_folded_value(self) -> "ExprNode":
except KeyError:
raise UnfoldableNode("not foldable", self)

def reduced(self) -> "ExprNode":
if self.has_folded_value:
return self.get_folded_value()
return self

def _set_folded_value(self, node: "VyperNode") -> None:
# sanity check this is only called once
assert "folded_value" not in self._metadata
Expand Down
1 change: 1 addition & 0 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class VyperNode:
def get_fields(cls: Any) -> set: ...
def set_parent(self, parent: VyperNode) -> VyperNode: ...
def get_folded_value(self) -> ExprNode: ...
def reduced(self) -> ExprNode: ...
def _set_folded_value(self, node: ExprNode) -> None: ...
@classmethod
def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ...
Expand Down
7 changes: 4 additions & 3 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ class Expr:

def __init__(self, node, context, is_stmt=False):
assert isinstance(node, vy_ast.VyperNode)
if node.has_folded_value:
node = node.get_folded_value()
node = node.reduced()

self.expr = node
self.context = context
Expand Down Expand Up @@ -347,7 +346,9 @@ def parse_Subscript(self):
index = Expr.parse_value_expr(self.expr.slice, self.context)

elif is_tuple_like(sub.typ):
index = self.expr.slice.n
# should we annotate expr.slice in the frontend with the
# folded value instead of calling reduced() here?
index = self.expr.slice.reduced().n
# note: this check should also happen in get_element_ptr
if not 0 <= index < len(sub.typ.member_types):
raise TypeCheckFailure("unreachable")
Expand Down
10 changes: 3 additions & 7 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,9 +527,7 @@ def _analyse_range_iter(self, iter_node, target_type):

def _analyse_list_iter(self, iter_node, target_type):
# iteration over a variable or literal list
iter_val = iter_node
if iter_val.has_folded_value:
iter_val = iter_val.get_folded_value()
iter_val = iter_node.reduced()

if isinstance(iter_val, vy_ast.List):
len_ = len(iter_val.elements)
Expand Down Expand Up @@ -946,12 +944,10 @@ def _validate_range_call(node: vy_ast.Call):
validate_call_args(node, (1, 2), kwargs=["bound"])
kwargs = {s.arg: s.value for s in node.keywords or []}
start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args
start, end = [i.get_folded_value() if i.has_folded_value else i for i in (start, end)]
start, end = [i.reduced() for i in (start, end)]

if "bound" in kwargs:
bound = kwargs["bound"]
if bound.has_folded_value:
bound = bound.get_folded_value()
bound = kwargs["bound"].reduced()
if not isinstance(bound, vy_ast.Int):
raise StructureException("Bound must be a literal integer", bound)
if bound.value <= 0:
Expand Down
9 changes: 6 additions & 3 deletions vyper/semantics/types/subscriptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def validate_index_type(self, node):
# TODO break this cycle
from vyper.semantics.analysis.utils import validate_expected_type

node = node.reduced()

if isinstance(node, vy_ast.Int):
if node.value < 0:
raise ArrayIndexException("Vyper does not support negative indexing", node)
Expand Down Expand Up @@ -290,9 +292,7 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT":
if not isinstance(node.slice, vy_ast.Tuple) or len(node.slice.elements) != 2:
raise StructureException(err_msg, node.slice)

length_node = node.slice.elements[1]
if length_node.has_folded_value:
length_node = length_node.get_folded_value()
length_node = node.slice.elements[1].reduced()

if not isinstance(length_node, vy_ast.Int):
raise StructureException(err_msg, length_node)
Expand Down Expand Up @@ -367,6 +367,8 @@ def size_in_bytes(self):
return sum(i.size_in_bytes for i in self.member_types)

def validate_index_type(self, node):
node = node.reduced()

if not isinstance(node, vy_ast.Int):
raise InvalidType("Tuple indexes must be literals", node)
if node.value < 0:
Expand All @@ -375,6 +377,7 @@ def validate_index_type(self, node):
raise ArrayIndexException("Index out of range", node)

def get_subscripted_type(self, node):
node = node.reduced()
return self.member_types[node.value]

def compare_type(self, other):
Expand Down
3 changes: 1 addition & 2 deletions vyper/semantics/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,7 @@ def get_index_value(node: vy_ast.VyperNode) -> int:
# TODO: revisit this!
from vyper.semantics.analysis.utils import get_possible_types_from_node

if node.has_folded_value:
node = node.get_folded_value()
node = node.reduced()

if not isinstance(node, vy_ast.Int):
# even though the subscript is an invalid type, first check if it's a valid _something_
Expand Down

0 comments on commit 7db4f87

Please sign in to comment.