Skip to content

Commit

Permalink
feat: implement bound= in ranges (#3537)
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper authored Jul 25, 2023
1 parent 019a37a commit d48438e
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 15 deletions.
35 changes: 34 additions & 1 deletion tests/functional/semantics/analysis/test_for_loop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import pytest

from vyper.ast import parse_to_ast
from vyper.exceptions import ImmutableViolation, TypeMismatch
from vyper.exceptions import (
ArgumentException,
ImmutableViolation,
StateAccessViolation,
TypeMismatch,
)
from vyper.semantics.analysis import validate_semantics


Expand Down Expand Up @@ -59,6 +64,34 @@ def bar():
validate_semantics(vyper_module, {})


def test_bad_keywords(namespace):
code = """
@internal
def bar(n: uint256):
x: uint256 = 0
for i in range(n, boundddd=10):
x += i
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ArgumentException):
validate_semantics(vyper_module, {})


def test_bad_bound(namespace):
code = """
@internal
def bar(n: uint256):
x: uint256 = 0
for i in range(n, bound=n):
x += i
"""
vyper_module = parse_to_ast(code)
with pytest.raises(StateAccessViolation):
validate_semantics(vyper_module, {})


def test_modify_iterator_function_call(namespace):
code = """
Expand Down
17 changes: 17 additions & 0 deletions tests/parser/features/iteration/test_for_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,23 @@ def repeat(z: int128) -> int128:
assert c.repeat(9) == 54


def test_range_bound(get_contract, assert_tx_failed):
code = """
@external
def repeat(n: uint256) -> uint256:
x: uint256 = 0
for i in range(n, bound=6):
x += i
return x
"""
c = get_contract(code)
for n in range(7):
assert c.repeat(n) == sum(range(n))

# check codegen inserts assertion for n greater than bound
assert_tx_failed(lambda: c.repeat(7))


def test_digit_reverser(get_contract_with_gas_estimation):
digit_reverser = """
@external
Expand Down
21 changes: 16 additions & 5 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,18 +258,25 @@ def _parse_For_range(self):
arg0 = self.stmt.iter.args[0]
num_of_args = len(self.stmt.iter.args)

kwargs = {
s.arg: Expr.parse_value_expr(s.value, self.context)
for s in self.stmt.iter.keywords or []
}

# Type 1 for, e.g. for i in range(10): ...
if num_of_args == 1:
arg0_val = self._get_range_const_value(arg0)
n = Expr.parse_value_expr(arg0, self.context)
start = IRnode.from_list(0, typ=iter_typ)
rounds = arg0_val
rounds = n
rounds_bound = kwargs.get("bound", rounds)

# Type 2 for, e.g. for i in range(100, 110): ...
elif self._check_valid_range_constant(self.stmt.iter.args[1]).is_literal:
arg0_val = self._get_range_const_value(arg0)
arg1_val = self._get_range_const_value(self.stmt.iter.args[1])
start = IRnode.from_list(arg0_val, typ=iter_typ)
rounds = IRnode.from_list(arg1_val - arg0_val, typ=iter_typ)
rounds_bound = rounds

# Type 3 for, e.g. for i in range(x, x + 10): ...
else:
Expand All @@ -278,9 +285,10 @@ def _parse_For_range(self):
start = Expr.parse_value_expr(arg0, self.context)
_, hi = start.typ.int_bounds
start = clamp("le", start, hi + 1 - rounds)
rounds_bound = rounds

r = rounds if isinstance(rounds, int) else rounds.value
if r < 1:
bound = rounds_bound if isinstance(rounds_bound, int) else rounds_bound.value
if bound < 1:
return

varname = self.stmt.target.id
Expand All @@ -294,7 +302,10 @@ def _parse_For_range(self):
loop_body.append(["mstore", iptr, i])
loop_body.append(parse_body(self.stmt.body, self.context))

ir_node = IRnode.from_list(["repeat", i, start, rounds, rounds, loop_body])
# NOTE: codegen for `repeat` inserts an assertion that rounds <= rounds_bound.
# if we ever want to remove that, we need to manually add the assertion
# where it makes sense.
ir_node = IRnode.from_list(["repeat", i, start, rounds, rounds_bound, loop_body])
del self.context.forvars[varname]

return ir_node
Expand Down
3 changes: 1 addition & 2 deletions vyper/ir/compile_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,8 @@ def _height_of(witharg):
)
# stack: i, rounds, rounds_bound
# assert rounds <= rounds_bound
# TODO this runtime assertion should never fail for
# TODO this runtime assertion shouldn't fail for
# internally generated repeats.
# maybe drop it or jump to 0xFE
o.extend(["DUP2", "GT"] + _assert_false())

# stack: i, rounds
Expand Down
3 changes: 3 additions & 0 deletions vyper/semantics/analysis/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def visit_For(self, node):
iter_type = node.target._metadata["type"]
for a in node.iter.args:
self.expr_visitor.visit(a, iter_type)
for a in node.iter.keywords:
if a.arg == "bound":
self.expr_visitor.visit(a.value, iter_type)


class ExpressionAnnotationVisitor(_AnnotationVisitorBase):
Expand Down
27 changes: 20 additions & 7 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,17 +346,30 @@ def visit_For(self, node):
raise IteratorException(
"Cannot iterate over the result of a function call", node.iter
)
validate_call_args(node.iter, (1, 2))
validate_call_args(node.iter, (1, 2), kwargs=["bound"])

args = node.iter.args
kwargs = {s.arg: s.value for s in node.iter.keywords or []}
if len(args) == 1:
# range(CONSTANT)
if not isinstance(args[0], vy_ast.Num):
raise StateAccessViolation("Value must be a literal", node)
if args[0].value <= 0:
raise StructureException("For loop must have at least 1 iteration", args[0])
validate_expected_type(args[0], IntegerT.any())
type_list = get_possible_types_from_node(args[0])
n = args[0]
bound = kwargs.pop("bound", None)
validate_expected_type(n, IntegerT.any())

if bound is None:
if not isinstance(n, vy_ast.Num):
raise StateAccessViolation("Value must be a literal", n)
if n.value <= 0:
raise StructureException("For loop must have at least 1 iteration", args[0])
type_list = get_possible_types_from_node(n)

else:
if not isinstance(bound, vy_ast.Num):
raise StateAccessViolation("bound must be a literal", bound)
if bound.value <= 0:
raise StructureException("bound must be at least 1", args[0])
type_list = get_common_types(n, bound)

else:
validate_expected_type(args[0], IntegerT.any())
type_list = get_common_types(*args)
Expand Down

0 comments on commit d48438e

Please sign in to comment.