From 5319cfbe14951e007ccdb323257e5ada869b35d5 Mon Sep 17 00:00:00 2001 From: Daniel Schiavini Date: Sun, 24 Dec 2023 17:10:45 +0100 Subject: [PATCH] feat: allow `range(x, y, bound=N)` (#3679) - allow range where both start and end arguments are variables, so long as a bound is supplied - ban range expressions of the form `range(x, x + N)` since the new form is cleaner and supersedes it. - also do a bit of refactoring of the codegen for range --------- Co-authored-by: Charles Cooper --- docs/control-structures.rst | 8 +- .../features/iteration/test_for_in_list.py | 19 +- .../features/iteration/test_for_range.py | 116 ++++++++++- .../codegen/integration/test_crowdfund.py | 4 +- .../test_invalid_literal_exception.py | 7 - tests/functional/syntax/test_for_range.py | 197 +++++++++++++++++- vyper/codegen/ir_node.py | 8 +- vyper/codegen/stmt.py | 67 +++--- vyper/exceptions.py | 2 +- vyper/semantics/analysis/local.py | 109 ++++------ 10 files changed, 390 insertions(+), 147 deletions(-) diff --git a/docs/control-structures.rst b/docs/control-structures.rst index 873135709a..2f890bcb2f 100644 --- a/docs/control-structures.rst +++ b/docs/control-structures.rst @@ -287,9 +287,11 @@ Another use of range can be with ``START`` and ``STOP`` bounds. Here, ``START`` and ``STOP`` are literal integers, with ``STOP`` being a greater value than ``START``. ``i`` begins as ``START`` and increments by one until it is equal to ``STOP``. +Finally, it is possible to use ``range`` with runtime `start` and `stop` values as long as a constant `bound` value is provided. +In this case, Vyper checks at runtime that `end - start <= bound`. +``N`` must be a compile-time constant. + .. code-block:: python - for i in range(a, a + N): + for i in range(start, end, bound=N): ... - -``a`` is a variable with an integer type and ``N`` is a literal integer greater than zero. ``i`` begins as ``a`` and increments by one until it is equal to ``a + N``. If ``a + N`` would overflow, execution will revert. diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index fb01cc98eb..bc1a12ae9e 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -1,3 +1,4 @@ +import re from decimal import Decimal import pytest @@ -700,13 +701,16 @@ def foo(): """, StateAccessViolation, ), - """ + ( + """ @external def foo(): a: int128 = 6 for i in range(a,a-3): pass """, + StateAccessViolation, + ), # invalid argument length ( """ @@ -789,10 +793,13 @@ def test_for() -> int128: ), ] +BAD_CODE = [code if isinstance(code, tuple) else (code, StructureException) for code in BAD_CODE] +for_code_regex = re.compile(r"for .+ in (.*):") +bad_code_names = [ + f"{i} {for_code_regex.search(code).group(1)}" for i, (code, _) in enumerate(BAD_CODE) +] + -@pytest.mark.parametrize("code", BAD_CODE) -def test_bad_code(assert_compile_failed, get_contract, code): - err = StructureException - if not isinstance(code, str): - code, err = code +@pytest.mark.parametrize("code,err", BAD_CODE, ids=bad_code_names) +def test_bad_code(assert_compile_failed, get_contract, code, err): assert_compile_failed(lambda: get_contract(code), err) diff --git a/tests/functional/codegen/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py index 96b83ae691..e946447285 100644 --- a/tests/functional/codegen/features/iteration/test_for_range.py +++ b/tests/functional/codegen/features/iteration/test_for_range.py @@ -32,6 +32,102 @@ def repeat(n: uint256) -> uint256: c.repeat(7) +def test_range_bound_constant_end(get_contract, tx_failed): + code = """ +@external +def repeat(n: uint256) -> uint256: + x: uint256 = 0 + for i in range(n, 7, bound=6): + x += i + 1 + return x + """ + c = get_contract(code) + for n in range(1, 5): + assert c.repeat(n) == sum(i + 1 for i in range(n, 7)) + + # check assertion for `start <= end` + with tx_failed(): + c.repeat(8) + # check assertion for `start + bound <= end` + with tx_failed(): + c.repeat(0) + + +def test_range_bound_two_args(get_contract, tx_failed): + code = """ +@external +def repeat(n: uint256) -> uint256: + x: uint256 = 0 + for i in range(1, n, bound=6): + x += i + 1 + return x + """ + c = get_contract(code) + for n in range(1, 8): + assert c.repeat(n) == sum(i + 1 for i in range(1, n)) + + # check assertion for `start <= end` + with tx_failed(): + c.repeat(0) + + # check codegen inserts assertion for `start + bound <= end` + with tx_failed(): + c.repeat(8) + + +def test_range_bound_two_runtime_args(get_contract, tx_failed): + code = """ +@external +def repeat(start: uint256, end: uint256) -> uint256: + x: uint256 = 0 + for i in range(start, end, bound=6): + x += i + return x + """ + c = get_contract(code) + for n in range(0, 7): + assert c.repeat(0, n) == sum(range(0, n)) + assert c.repeat(n, n * 2) == sum(range(n, n * 2)) + + # check assertion for `start <= end` + with tx_failed(): + c.repeat(1, 0) + with tx_failed(): + c.repeat(7, 0) + with tx_failed(): + c.repeat(8, 7) + + # check codegen inserts assertion for `start + bound <= end` + with tx_failed(): + c.repeat(0, 7) + with tx_failed(): + c.repeat(14, 21) + + +def test_range_overflow(get_contract, tx_failed): + code = """ +@external +def get_last(start: uint256, end: uint256) -> uint256: + x: uint256 = 0 + for i in range(start, end, bound=6): + x = i + return x + """ + c = get_contract(code) + UINT_MAX = 2**256 - 1 + assert c.get_last(UINT_MAX, UINT_MAX) == 0 # initial value of x + + for n in range(1, 6): + assert c.get_last(UINT_MAX - n, UINT_MAX) == UINT_MAX - 1 + + # check for `start + bound <= end`, overflow cases + for n in range(1, 7): + with tx_failed(): + c.get_last(UINT_MAX - n, 0) + with tx_failed(): + c.get_last(UINT_MAX, UINT_MAX - n) + + def test_digit_reverser(get_contract_with_gas_estimation): digit_reverser = """ @external @@ -89,7 +185,7 @@ def test_offset_repeater_2(get_contract_with_gas_estimation, typ): @external def sum(frm: {typ}, to: {typ}) -> {typ}: out: {typ} = 0 - for i in range(frm, frm + 101): + for i in range(frm, frm + 101, bound=101): if i == to: break out = out + i @@ -146,26 +242,28 @@ def foo(a: {typ}) -> {typ}: assert c.foo(100) == 31337 -# test that we can get to the upper range of an integer @pytest.mark.parametrize("typ", ["uint8", "int128", "uint256"]) def test_for_range_edge(get_contract, typ): + """ + Check that we can get to the upper range of an integer. + Note that to avoid overflow in the bounds check for range(), + we need to calculate i+1 inside the loop. + """ code = f""" @external def test(): found: bool = False x: {typ} = max_value({typ}) - for i in range(x, x + 1): - if i == max_value({typ}): + for i in range(x - 1, x, bound=1): + if i + 1 == max_value({typ}): found = True - assert found found = False x = max_value({typ}) - 1 - for i in range(x, x + 2): - if i == max_value({typ}): + for i in range(x - 1, x + 1, bound=2): + if i + 1 == max_value({typ}): found = True - assert found """ c = get_contract(code) @@ -178,7 +276,7 @@ def test_for_range_oob_check(get_contract, tx_failed, typ): @external def test(): x: {typ} = max_value({typ}) - for i in range(x, x+2): + for i in range(x, x + 2, bound=2): pass """ c = get_contract(code) diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 2083e62610..671d424d60 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -52,7 +52,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30): + for i in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return @@ -147,7 +147,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30): + for i in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return diff --git a/tests/functional/syntax/exceptions/test_invalid_literal_exception.py b/tests/functional/syntax/exceptions/test_invalid_literal_exception.py index 1f4f112252..a0cf10ad02 100644 --- a/tests/functional/syntax/exceptions/test_invalid_literal_exception.py +++ b/tests/functional/syntax/exceptions/test_invalid_literal_exception.py @@ -18,13 +18,6 @@ def foo(): """, """ @external -def foo(x: int128): - y: int128 = 7 - for i in range(x, x + y): - pass - """, - """ -@external def foo(): x: String[100] = "these bytes are nо gооd because the o's are from the Russian alphabet" """, diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index e6f35c1d2d..7c7f9c476d 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -1,7 +1,9 @@ +import re + import pytest from vyper import compiler -from vyper.exceptions import StructureException +from vyper.exceptions import ArgumentException, StateAccessViolation, StructureException fail_list = [ ( @@ -12,33 +14,191 @@ def foo(): pass """, StructureException, + "Invalid syntax for loop iterator", + "a[1]", + ), + ( + """ +@external +def foo(): + x: uint256 = 100 + for _ in range(10, bound=x): + pass + """, + StateAccessViolation, + "Bound must be a literal", + "x", + ), + ( + """ +@external +def foo(): + for _ in range(10, 20, bound=5): + pass + """, + StructureException, + "Please remove the `bound=` kwarg when using range with constants", + "5", + ), + ( + """ +@external +def foo(): + for _ in range(10, 20, bound=0): + pass + """, + StructureException, + "Bound must be at least 1", + "0", + ), + ( + """ +@external +def bar(): + x:uint256 = 1 + for i in range(x,x+1,bound=2,extra=3): + pass + """, + ArgumentException, + "Invalid keyword argument 'extra'", + "extra=3", ), ( """ @external def bar(): - for i in range(1,2,bound=2): + for i in range(0): pass """, StructureException, + "End must be greater than start", + "0", + ), + ( + """ +@external +def bar(): + x:uint256 = 1 + for i in range(x): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +@external +def bar(): + x:uint256 = 1 + for i in range(0, x): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +@external +def repeat(n: uint256) -> uint256: + for i in range(0, n * 10): + pass + return n + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "n * 10", ), ( """ @external def bar(): x:uint256 = 1 - for i in range(x,x+1,bound=2): + for i in range(0, x + 1): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x + 1", + ), + ( + """ +@external +def bar(): + for i in range(2, 1): pass """, StructureException, + "End must be greater than start", + "1", + ), + ( + """ +@external +def bar(): + x:uint256 = 1 + for i in range(x, x): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +@external +def foo(): + x: int128 = 5 + for i in range(x, x + 10): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +@external +def repeat(n: uint256) -> uint256: + for i in range(n, 6): + pass + return x + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "n", + ), + ( + """ +@external +def foo(x: int128): + y: int128 = 7 + for i in range(x, x + y): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", ), ] +for_code_regex = re.compile(r"for .+ in (.*):") +fail_test_names = [ + ( + f"{i:02d}: {for_code_regex.search(code).group(1)}" # type: ignore[union-attr] + f" raises {type(err).__name__}" + ) + for i, (code, err, msg, src) in enumerate(fail_list) +] -@pytest.mark.parametrize("bad_code", fail_list) -def test_range_fail(bad_code): - with pytest.raises(bad_code[1]): - compiler.compile_code(bad_code[0]) + +@pytest.mark.parametrize("bad_code,error_type,message,source_code", fail_list, ids=fail_test_names) +def test_range_fail(bad_code, error_type, message, source_code): + with pytest.raises(error_type) as exc_info: + compiler.compile_code(bad_code) + assert message == exc_info.value.message + assert source_code == exc_info.value.args[1].node_source_code valid_list = [ @@ -58,7 +218,21 @@ def foo(): @external def foo(): x: int128 = 5 - for i in range(x, x + 10): + for i in range(1, x, bound=4): + pass + """, + """ +@external +def foo(): + x: int128 = 5 + for i in range(x, bound=4): + pass + """, + """ +@external +def foo(): + x: int128 = 5 + for i in range(0, x, bound=4): pass """, """ @@ -72,7 +246,12 @@ def kick_foos(): """, ] +valid_test_names = [ + f"{i} {for_code_regex.search(code).group(1)}" # type: ignore[union-attr] + for i, code in enumerate(valid_list) +] + -@pytest.mark.parametrize("good_code", valid_list) +@pytest.mark.parametrize("good_code", valid_list, ids=valid_test_names) def test_range_success(good_code): assert compiler.compile_code(good_code) is not None diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index ce26066968..45d93f3067 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -444,11 +444,15 @@ def unique_symbols(self): return ret @property - def is_literal(self): + def is_literal(self) -> bool: return isinstance(self.value, int) or self.value == "multi" + def int_value(self) -> int: + assert isinstance(self.value, int) + return self.value + @property - def is_pointer(self): + def is_pointer(self) -> bool: # not used yet but should help refactor/clarify downstream code # eventually return self.location is not None diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 601597771c..18e5c3d494 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -225,15 +225,6 @@ def parse_Raise(self): else: return IRnode.from_list(["revert", 0, 0], error_msg="user raise") - def _check_valid_range_constant(self, arg_ast_node): - with self.context.range_scope(): - arg_expr = Expr.parse_value_expr(arg_ast_node, self.context) - return arg_expr - - def _get_range_const_value(self, arg_ast_node): - arg_expr = self._check_valid_range_constant(arg_ast_node) - return arg_expr.value - def parse_For(self): with self.context.block_scope(): if self.stmt.get("iter.func.id") == "range": @@ -249,41 +240,37 @@ def _parse_For_range(self): iter_typ = INT256_T # Get arg0 - arg0 = self.stmt.iter.args[0] - num_of_args = len(self.stmt.iter.args) - - kwargs = { - s.arg: Expr.parse_value_expr(s.value, self.context) - for s in self.stmt.iter.keywords or [] - } - - # Type 1 for, e.g. for i in range(10): ... - if num_of_args == 1: - n = Expr.parse_value_expr(arg0, self.context) - start = IRnode.from_list(0, typ=iter_typ) - rounds = n - rounds_bound = kwargs.get("bound", rounds) - - # Type 2 for, e.g. for i in range(100, 110): ... - elif self._check_valid_range_constant(self.stmt.iter.args[1]).is_literal: - arg0_val = self._get_range_const_value(arg0) - arg1_val = self._get_range_const_value(self.stmt.iter.args[1]) - start = IRnode.from_list(arg0_val, typ=iter_typ) - rounds = IRnode.from_list(arg1_val - arg0_val, typ=iter_typ) - rounds_bound = rounds + for_iter: vy_ast.Call = self.stmt.iter + args_len = len(for_iter.args) + if args_len == 1: + arg0, arg1 = (IRnode.from_list(0, typ=iter_typ), for_iter.args[0]) + elif args_len == 2: + arg0, arg1 = for_iter.args + else: # pragma: nocover + raise TypeCheckFailure("unreachable: bad # of arguments to range()") - # Type 3 for, e.g. for i in range(x, x + 10): ... - else: - arg1 = self.stmt.iter.args[1] - rounds = self._get_range_const_value(arg1.right) + with self.context.range_scope(): start = Expr.parse_value_expr(arg0, self.context) - _, hi = start.typ.int_bounds - start = clamp("le", start, hi + 1 - rounds) + end = Expr.parse_value_expr(arg1, self.context) + kwargs = { + s.arg: Expr.parse_value_expr(s.value, self.context) for s in for_iter.keywords + } + + if "bound" in kwargs: + with end.cache_when_complex("end") as (b1, end): + # note: the check for rounds<=rounds_bound happens in asm + # generation for `repeat`. + clamped_start = clamp("le", start, end) + rounds = b1.resolve(IRnode.from_list(["sub", end, clamped_start])) + rounds_bound = kwargs.pop("bound").int_value() + else: + rounds = end.int_value() - start.int_value() rounds_bound = rounds - bound = rounds_bound if isinstance(rounds_bound, int) else rounds_bound.value - if bound < 1: - return + assert len(kwargs) == 0 # sanity check stray keywords + + if rounds_bound < 1: # pragma: nocover + raise TypeCheckFailure("unreachable: unchecked 0 bound") varname = self.stmt.target.id i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=UINT256_T) diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 8f72d9afc9..8921814188 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -41,7 +41,7 @@ def __init__(self, message="Error Message not found.", *items): Error message to display with the exception. *items : VyperNode | Tuple[str, VyperNode], optional Vyper ast node(s), or tuple of (description, node) indicating where - the exception occured. Source annotations are generated in the order + the exception occurred. Source annotations are generated in the order the nodes are given. A single tuple of (lineno, col_offset) is also understood to support diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 2a84f69ad4..a3ebf85fa2 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -7,7 +7,6 @@ ExceptionList, FunctionDeclarationException, ImmutableViolation, - InvalidLiteral, InvalidOperation, InvalidType, IteratorException, @@ -355,71 +354,7 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - range_ = node.iter - validate_call_args(range_, (1, 2), kwargs=["bound"]) - - args = range_.args - kwargs = {s.arg: s.value for s in range_.keywords or []} - if len(args) == 1: - # range(CONSTANT) - n = args[0] - bound = kwargs.pop("bound", None) - validate_expected_type(n, IntegerT.any()) - - if bound is None: - if not isinstance(n, vy_ast.Num): - raise StateAccessViolation("Value must be a literal", n) - if n.value <= 0: - raise StructureException("For loop must have at least 1 iteration", args[0]) - type_list = get_possible_types_from_node(n) - - else: - if not isinstance(bound, vy_ast.Num): - raise StateAccessViolation("bound must be a literal", bound) - if bound.value <= 0: - raise StructureException("bound must be at least 1", args[0]) - type_list = get_common_types(n, bound) - - else: - if range_.keywords: - raise StructureException( - "Keyword arguments are not supported for `range(N, M)` and" - "`range(x, x + N)` expressions", - range_.keywords[0], - ) - - validate_expected_type(args[0], IntegerT.any()) - type_list = get_common_types(*args) - if not isinstance(args[0], vy_ast.Constant): - # range(x, x + CONSTANT) - if not isinstance(args[1], vy_ast.BinOp) or not isinstance( - args[1].op, vy_ast.Add - ): - raise StructureException( - "Second element must be the first element plus a literal value", args[0] - ) - if not vy_ast.compare_nodes(args[0], args[1].left): - raise StructureException( - "First and second variable must be the same", args[1].left - ) - if not isinstance(args[1].right, vy_ast.Int): - raise InvalidLiteral("Literal must be an integer", args[1].right) - if args[1].right.value < 1: - raise StructureException( - f"For loop has invalid number of iterations ({args[1].right.value})," - " the value must be greater than zero", - args[1].right, - ) - else: - # range(CONSTANT, CONSTANT) - if not isinstance(args[1], vy_ast.Int): - raise InvalidType("Value must be a literal integer", args[1]) - validate_expected_type(args[1], IntegerT.any()) - if args[0].value >= args[1].value: - raise StructureException("Second value must be > first value", args[1]) - - if not type_list: - raise TypeMismatch("Iterator values are of different types", node.iter) + type_list = _analyse_range_call(node.iter) else: # iteration over a variable or literal list @@ -490,8 +425,8 @@ def visit_For(self, node): try: with NodeMetadata.enter_typechecker_speculation(): - for n in node.body: - self.visit(n) + for stmt in node.body: + self.visit(stmt) except (TypeMismatch, InvalidOperation) as exc: for_loop_exceptions.append(exc) else: @@ -801,3 +736,41 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.body, typ) validate_expected_type(node.orelse, typ) self.visit(node.orelse, typ) + + +def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: + """ + Check that the arguments to a range() call are valid. + :param node: call to range() + :return: None + """ + validate_call_args(node, (1, 2), kwargs=["bound"]) + kwargs = {s.arg: s.value for s in node.keywords or []} + start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args + + all_args = (start, end, *kwargs.values()) + for arg1 in all_args: + validate_expected_type(arg1, IntegerT.any()) + + type_list = get_common_types(*all_args) + if not type_list: + raise TypeMismatch("Iterator values are of different types", node) + + if "bound" in kwargs: + bound = kwargs["bound"] + if not isinstance(bound, vy_ast.Num): + raise StateAccessViolation("Bound must be a literal", bound) + if bound.value <= 0: + raise StructureException("Bound must be at least 1", bound) + if isinstance(start, vy_ast.Num) and isinstance(end, vy_ast.Num): + error = "Please remove the `bound=` kwarg when using range with constants" + raise StructureException(error, bound) + else: + for arg in (start, end): + if not isinstance(arg, vy_ast.Num): + error = "Value must be a literal integer, unless a bound is specified" + raise StateAccessViolation(error, arg) + if end.value <= start.value: + raise StructureException("End must be greater than start", end) + + return type_list