From ee11e3db7b2d8cd96fcb24940406f7284bddf535 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 4 Apr 2024 11:57:00 -0400 Subject: [PATCH] fix[test]: fix a bad bound in decimal fuzzing (#3909) the decimal fuzz test `is_valid` condition was based on an ancient version of decimals which had bounds at `-2**127` and `2**127`. update the condition to be compatible with the latest version of `decimal`. also increase the range of decimals produced by the decimal fuzzing strategy, so that the fuzzer finds overflow issues faster. an additional issue was found in the fuzz tests, which is that some decimal operations panic with `decimal.InvalidOperation` instead of a proper exception. this is a known bug, see GH #2241. this fixes the issue by catching the exception and raising an `OverflowException`. misc/refactor: - refactor several uses of quantize into a utility function --- .../builtins/codegen/test_convert.py | 3 +- .../codegen/types/numbers/test_decimals.py | 8 ++--- .../unit/ast/nodes/test_fold_binop_decimal.py | 24 +++++++++++---- vyper/ast/nodes.py | 29 +++++++++++++++---- vyper/utils.py | 5 ++++ 5 files changed, 51 insertions(+), 18 deletions(-) diff --git a/tests/functional/builtins/codegen/test_convert.py b/tests/functional/builtins/codegen/test_convert.py index ad1a616300..27a3d0d519 100644 --- a/tests/functional/builtins/codegen/test_convert.py +++ b/tests/functional/builtins/codegen/test_convert.py @@ -17,6 +17,7 @@ checksum_encode, int_bounds, is_checksum_encoded, + quantize, round_towards_zero, unsigned_to_signed, ) @@ -414,7 +415,7 @@ def _vyper_literal(val, typ): return "0x" + val.hex() if isinstance(typ, DecimalT): tmp = val - val = val.quantize(DECIMAL_EPSILON) + val = quantize(val) assert tmp == val return str(val) diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index 425440fd4b..04cf267e4c 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -1,5 +1,5 @@ import warnings -from decimal import ROUND_DOWN, Decimal, getcontext +from decimal import Decimal, getcontext import pytest @@ -10,7 +10,7 @@ OverflowException, TypeMismatch, ) -from vyper.utils import DECIMAL_EPSILON, SizeLimits +from vyper.utils import DECIMAL_EPSILON, SizeLimits, quantize def test_decimal_override(): @@ -51,10 +51,6 @@ def foo(x: decimal) -> decimal: compile_code(code) -def quantize(x: Decimal) -> Decimal: - return x.quantize(DECIMAL_EPSILON, rounding=ROUND_DOWN) - - def test_decimal_test(get_contract_with_gas_estimation): decimal_test = """ @external diff --git a/tests/unit/ast/nodes/test_fold_binop_decimal.py b/tests/unit/ast/nodes/test_fold_binop_decimal.py index a75d114f88..837861d010 100644 --- a/tests/unit/ast/nodes/test_fold_binop_decimal.py +++ b/tests/unit/ast/nodes/test_fold_binop_decimal.py @@ -6,9 +6,17 @@ from tests.utils import parse_and_fold from vyper.exceptions import OverflowException, TypeMismatch, ZeroDivisionException +from vyper.semantics.analysis.local import ExprVisitor +from vyper.semantics.types import DecimalT + +DECIMAL_T = DecimalT() st_decimals = st.decimals( - min_value=-(2**32), max_value=2**32, allow_nan=False, allow_infinity=False, places=10 + min_value=DECIMAL_T.decimal_bounds[0], + max_value=DECIMAL_T.decimal_bounds[1], + allow_nan=False, + allow_infinity=False, + places=DECIMAL_T._decimal_places, ) @@ -30,10 +38,11 @@ def foo(a: decimal, b: decimal) -> decimal: try: vyper_ast = parse_and_fold(f"{left} {op} {right}") - old_node = vyper_ast.body[0].value - new_node = old_node.get_folded_value() + expr = vyper_ast.body[0].value + ExprVisitor().visit(expr, DecimalT()) + new_node = expr.get_folded_value() is_valid = True - except ZeroDivisionException: + except (OverflowException, ZeroDivisionException): is_valid = False if is_valid: @@ -71,9 +80,12 @@ def foo({input_value}) -> decimal: literal_op = literal_op.rsplit(maxsplit=1)[0] try: vyper_ast = parse_and_fold(literal_op) - new_node = vyper_ast.body[0].value.get_folded_value() + expr = vyper_ast.body[0].value + ExprVisitor().visit(expr, DecimalT()) + new_node = expr.get_folded_value() expected = new_node.value - is_valid = -(2**127) <= expected < 2**127 + lo, hi = DecimalT().decimal_bounds + is_valid = lo <= expected < hi except (OverflowException, ZeroDivisionException): # for overflow or division/modulus by 0, expect the contract call to revert is_valid = False diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 4ba2d1a593..c78ecb6d89 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -26,7 +26,14 @@ VyperException, ZeroDivisionException, ) -from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code, evm_div, sha256sum +from vyper.utils import ( + MAX_DECIMAL_PLACES, + SizeLimits, + annotate_source_code, + evm_div, + quantize, + sha256sum, +) NODE_BASE_ATTRIBUTES = ( "_children", @@ -824,6 +831,7 @@ def to_dict(self): return ast_dict def validate(self): + # note: maybe use self.value == quantize(self.value) for this check if self.value.as_tuple().exponent < -MAX_DECIMAL_PLACES: raise InvalidLiteral("Vyper supports a maximum of ten decimal points", self) if self.value < SizeLimits.MIN_AST_DECIMAL: @@ -1010,9 +1018,15 @@ def _op(self, left, right): value = left * right if isinstance(left, decimal.Decimal): # ensure that the result is truncated to MAX_DECIMAL_PLACES - return value.quantize( - decimal.Decimal(f"{1:0.{MAX_DECIMAL_PLACES}f}"), decimal.ROUND_DOWN - ) + try: + # if the intermediate result requires too many decimal places, + # decimal will puke - catch the error and raise an + # OverflowException + return quantize(value) + except decimal.InvalidOperation: + msg = f"{self._description} requires too many decimal places:" + msg += f"\n {left} * {right} => {value}" + raise OverflowException(msg, self) from None else: return value @@ -1036,7 +1050,12 @@ def _op(self, left, right): # the EVM always truncates toward zero value = -(-left / right) # ensure that the result is truncated to MAX_DECIMAL_PLACES - return value.quantize(decimal.Decimal(f"{1:0.{MAX_DECIMAL_PLACES}f}"), decimal.ROUND_DOWN) + try: + return quantize(value) + except decimal.InvalidOperation: + msg = f"{self._description} requires too many decimal places:" + msg += f"\n {left} {self._pretty} {right} => {value}" + raise OverflowException(msg, self) from None class FloorDiv(VyperNode): diff --git a/vyper/utils.py b/vyper/utils.py index 114ddf97c2..cf8a709997 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -407,6 +407,11 @@ class SizeLimits: MAX_UINT256 = 2**256 - 1 +def quantize(d: decimal.Decimal, places=MAX_DECIMAL_PLACES, rounding_mode=decimal.ROUND_DOWN): + quantizer = decimal.Decimal(f"{1:0.{places}f}") + return d.quantize(quantizer, rounding_mode) + + # List of valid IR macros. # TODO move this somewhere else, like ir_node.py VALID_IR_MACROS = {