Skip to content

Commit

Permalink
fix: iterator modification analysis (#3764)
Browse files Browse the repository at this point in the history
This commit fixes several bugs in the analysis of iterator modification
in loops. To do so, it refactors the analysis code to track reads/writes
more accurately, and uses the analysis machinery instead of AST queries
to perform the check. It also enriches ExprInfo with an `attr` attribute,
which can be used to detect whether an ExprInfo is derived from an
`Attribute`.

In the future, ExprInfo could be further enriched with `Subscript` info
so that the full Attribute/Subscript chain can be reliably recovered from
ExprInfos alone, which would help if other functions come to rely on
being able to recover the attribute chain.

This commit also modifies `validate_functions` so that it validates
functions in dependency (call graph traversal) order rather than in the
order in which they appear in the AST.

refactors:
- add `enter_for_loop()` context manager for convenience+clarity
- remove `ExprInfo.attribute_chain`, it was too confusing
- hide `ContractFunctionT` member variables (`_variable_reads`,
  `_variable_writes`, `_used_modules`) behind public-facing API
- remove `get_root_varinfo()` in favor of a helper
  `_get_variable_access()` function which detects access on variable
  sub-members (e.g., structs).
  • Loading branch information
charles-cooper authored Feb 13, 2024
1 parent a3bc3eb commit 7bdebbf
Show file tree
Hide file tree
Showing 13 changed files with 505 additions and 216 deletions.
56 changes: 55 additions & 1 deletion tests/functional/codegen/features/iteration/test_for_in_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from vyper.compiler import compile_code
from vyper.exceptions import (
ArgumentException,
ImmutableViolation,
Expand Down Expand Up @@ -841,6 +842,59 @@ def foo():
]


# TODO: move these to tests/functional/syntax
@pytest.mark.parametrize("code,err", BAD_CODE, ids=bad_code_names)
def test_bad_code(assert_compile_failed, get_contract, code, err):
# each (code, err) pair from BAD_CODE is expected to fail with `err`
# NOTE(review): this is a flattened diff view; the `assert_compile_failed`
# line below looks like the removed line and the `pytest.raises` block its
# replacement -- confirm against the raw diff before relying on this.
assert_compile_failed(lambda: get_contract(code), err)
with pytest.raises(err):
compile_code(code)


def test_iterator_modification_module_attribute(make_input_bundle):
# test modifying iterator via attribute
# lib1 declares the `queue` variable that the main contract iterates over
lib1 = """
queue: DynArray[uint256, 5]
"""
# main loops over `lib1.queue` and pops from that same array in the loop body
main = """
import lib1
initializes: lib1
@external
def foo():
for i: uint256 in lib1.queue:
lib1.queue.pop()
"""

# make lib1 importable by the main contract under the name `lib1.vy`
input_bundle = make_input_bundle({"lib1.vy": lib1})

# compilation must be rejected with an ImmutableViolation
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

# the message names the variable (`queue`), not the module-qualified path
assert e.value._message == "Cannot modify loop variable `queue`"


def test_iterator_modification_module_function_call(make_input_bundle):
# test modifying iterator via a call to a function of the imported module
# lib1 declares `queue` and an internal function `popqueue` that pops from it
lib1 = """
queue: DynArray[uint256, 5]
@internal
def popqueue():
self.queue.pop()
"""
# main loops over `lib1.queue` and calls `lib1.popqueue()` in the loop body,
# so the write to the iterator happens inside the callee, not in the loop
main = """
import lib1
initializes: lib1
@external
def foo():
for i: uint256 in lib1.queue:
lib1.popqueue()
"""

# make lib1 importable by the main contract under the name `lib1.vy`
input_bundle = make_input_bundle({"lib1.vy": lib1})

# compilation must be rejected with an ImmutableViolation
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot modify loop variable `queue`"
42 changes: 42 additions & 0 deletions tests/functional/syntax/modules/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,48 @@ def foo(new_value: uint256):
assert e.value._hint == expected_hint


def test_missing_uses_subscript(make_input_bundle):
# test missing uses through nested subscript/attribute access
# lib1 declares a struct with an array member and an array of such structs
lib1 = """
struct Foo:
array: uint256[5]
foos: Foo[5]
"""
# lib2 imports lib1 and declares a variable of its own (`counter`)
lib2 = """
import lib1
counter: uint256
@internal
def foo():
pass
"""
# main only declares `initializes: lib1`, yet writes to lib1's `foos`
# through the path `lib2.lib1`, mixing subscript and attribute access
main = """
import lib1
import lib2
initializes: lib1
# did not `use` or `initialize` lib2!
@external
def foo(new_value: uint256):
# cannot access lib1 state through lib2
lib2.lib1.foos[0].array[1] = new_value
"""
input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2})

# compilation must be rejected with an ImmutableViolation
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

# the error is reported against lib2, the module main failed to declare
assert e.value._message == "Cannot access `lib2` state!"

# the hint is built in two pieces and compared as one string
expected_hint = "add `uses: lib2` or `initializes: lib2` as a "
expected_hint += "top-level statement to your contract"
assert e.value._hint == expected_hint


def test_missing_uses_nested_attribute_function_call(make_input_bundle):
# test missing uses through nested attribute access
lib1 = """
Expand Down
105 changes: 105 additions & 0 deletions tests/unit/semantics/analysis/test_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,111 @@ def baz():
validate_semantics(vyper_module, dummy_input_bundle)


def test_modify_iterator_recursive_function_call_topsort(dummy_input_bundle):
# test the analysis works no matter the order of functions
# `baz` (which contains the loop over `self.a`) is declared first; it calls
# `bar`, which calls `foo`, and only `foo` writes to `self.a`. The callees
# are declared after their callers.
code = """
a: uint256[3]
@internal
def baz():
for i: uint256 in self.a:
self.bar()
@internal
def bar():
self.foo()
@internal
def foo():
self.a[0] = 1
"""
vyper_module = parse_to_ast(code)
# semantic validation must be rejected with an ImmutableViolation
with pytest.raises(ImmutableViolation) as e:
validate_semantics(vyper_module, dummy_input_bundle)

assert e.value._message == "Cannot modify loop variable `a`"


def test_modify_iterator_through_struct(dummy_input_bundle):
# GH issue 3429
# the loop iterates over the struct member `self.a.iter`, while the loop
# body reassigns the whole struct `self.a`
code = """
struct A:
iter: DynArray[uint256, 5]
a: A
@external
def foo():
self.a.iter = [1, 2, 3]
for i: uint256 in self.a.iter:
self.a = A({iter: [1, 2, 3, 4]})
"""
vyper_module = parse_to_ast(code)
# semantic validation must be rejected with an ImmutableViolation
with pytest.raises(ImmutableViolation) as e:
validate_semantics(vyper_module, dummy_input_bundle)

# the message names the top-level variable `a`, not the member `iter`
assert e.value._message == "Cannot modify loop variable `a`"


def test_modify_iterator_complex_expr(dummy_input_bundle):
# GH issue 3429
# avoid false positive!
# the loop body only reads `self.a` (as an index into `self.b`); the
# variable being written is `self.b`, not the iterator `self.a`
code = """
a: DynArray[uint256, 5]
b: uint256[10]
@external
def foo():
self.a = [1, 2, 3]
for i: uint256 in self.a:
self.b[self.a[1]] = i
"""
vyper_module = parse_to_ast(code)
# no pytest.raises here: validation is expected to complete without raising
validate_semantics(vyper_module, dummy_input_bundle)


def test_modify_iterator_siblings(dummy_input_bundle):
# test we can modify siblings in an access tree
# the loop iterates over member `a` of struct `self.f`, while the loop body
# writes to the other member `b` of the same struct
code = """
struct Foo:
a: uint256[2]
b: uint256
f: Foo
@external
def foo():
for i: uint256 in self.f.a:
self.f.b += i
"""
vyper_module = parse_to_ast(code)
# no pytest.raises here: validation is expected to complete without raising
validate_semantics(vyper_module, dummy_input_bundle)


def test_modify_subscript_barrier(dummy_input_bundle):
# test that Subscript nodes are a barrier for analysis
# the iterator is `self.b.f[1].x` and the write goes to `self.b.f[0].y`,
# i.e. a different array element and a different struct member; the test
# nevertheless expects the write to be rejected
code = """
struct Foo:
x: uint256[2]
y: uint256
struct Bar:
f: Foo[2]
b: Bar
@external
def foo():
for i: uint256 in self.b.f[1].x:
self.b.f[0].y += i
"""
vyper_module = parse_to_ast(code)
# semantic validation must be rejected with an ImmutableViolation
with pytest.raises(ImmutableViolation) as e:
validate_semantics(vyper_module, dummy_input_bundle)

# the message names the top-level variable `b`
assert e.value._message == "Cannot modify loop variable `b`"


iterator_inference_codes = [
"""
@external
Expand Down
8 changes: 4 additions & 4 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,13 @@ class Call(ExprNode):

class keyword(VyperNode): ...

class Attribute(VyperNode):
class Attribute(ExprNode):
attr: str = ...
value: ExprNode = ...

class Subscript(VyperNode):
slice: VyperNode = ...
value: VyperNode = ...
class Subscript(ExprNode):
slice: ExprNode = ...
value: ExprNode = ...

class Assign(VyperNode): ...

Expand Down
58 changes: 30 additions & 28 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,24 +263,6 @@ def parse_Attribute(self):
if addr.value == "address": # for `self.code`
return IRnode.from_list(["~selfcode"], typ=BytesT(0))
return IRnode.from_list(["~extcode", addr], typ=BytesT(0))
# self.x: global attribute
elif (varinfo := self.expr._expr_info.var_info) is not None:
if varinfo.is_constant:
return Expr.parse_value_expr(varinfo.decl_node.value, self.context)

location = data_location_to_address_space(
varinfo.location, self.context.is_ctor_context
)

ret = IRnode.from_list(
varinfo.position.position,
typ=varinfo.typ,
location=location,
annotation="self." + self.expr.attr,
)
ret._referenced_variables = {varinfo}

return ret

# Reserved keywords
elif (
Expand Down Expand Up @@ -336,17 +318,37 @@ def parse_Attribute(self):
"chain.id is unavailable prior to istanbul ruleset", self.expr
)
return IRnode.from_list(["chainid"], typ=UINT256_T)

# Other variables
else:
sub = Expr(self.expr.value, self.context).ir_node
# contract type
if isinstance(sub.typ, InterfaceT):
# MyInterface.address
assert self.expr.attr == "address"
sub.typ = typ
return sub
if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types:
return get_element_ptr(sub, self.expr.attr)

# self.x: global attribute
if (varinfo := self.expr._expr_info.var_info) is not None:
if varinfo.is_constant:
return Expr.parse_value_expr(varinfo.decl_node.value, self.context)

location = data_location_to_address_space(
varinfo.location, self.context.is_ctor_context
)

ret = IRnode.from_list(
varinfo.position.position,
typ=varinfo.typ,
location=location,
annotation="self." + self.expr.attr,
)
ret._referenced_variables = {varinfo}

return ret

sub = Expr(self.expr.value, self.context).ir_node
# contract type
if isinstance(sub.typ, InterfaceT):
# MyInterface.address
assert self.expr.attr == "address"
sub.typ = typ
return sub
if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types:
return get_element_ptr(sub, self.expr.attr)

def parse_Subscript(self):
sub = Expr(self.expr.value, self.context).ir_node
Expand Down
Loading

0 comments on commit 7bdebbf

Please sign in to comment.