diff --git a/tests/functional/builtins/codegen/test_convert.py b/tests/functional/builtins/codegen/test_convert.py index ad1a616300..27a3d0d519 100644 --- a/tests/functional/builtins/codegen/test_convert.py +++ b/tests/functional/builtins/codegen/test_convert.py @@ -17,6 +17,7 @@ checksum_encode, int_bounds, is_checksum_encoded, + quantize, round_towards_zero, unsigned_to_signed, ) @@ -414,7 +415,7 @@ def _vyper_literal(val, typ): return "0x" + val.hex() if isinstance(typ, DecimalT): tmp = val - val = val.quantize(DECIMAL_EPSILON) + val = quantize(val) assert tmp == val return str(val) diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index 425440fd4b..04cf267e4c 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -1,5 +1,5 @@ import warnings -from decimal import ROUND_DOWN, Decimal, getcontext +from decimal import Decimal, getcontext import pytest @@ -10,7 +10,7 @@ OverflowException, TypeMismatch, ) -from vyper.utils import DECIMAL_EPSILON, SizeLimits +from vyper.utils import DECIMAL_EPSILON, SizeLimits, quantize def test_decimal_override(): @@ -51,10 +51,6 @@ def foo(x: decimal) -> decimal: compile_code(code) -def quantize(x: Decimal) -> Decimal: - return x.quantize(DECIMAL_EPSILON, rounding=ROUND_DOWN) - - def test_decimal_test(get_contract_with_gas_estimation): decimal_test = """ @external diff --git a/tests/unit/ast/nodes/test_fold_binop_decimal.py b/tests/unit/ast/nodes/test_fold_binop_decimal.py index a75d114f88..837861d010 100644 --- a/tests/unit/ast/nodes/test_fold_binop_decimal.py +++ b/tests/unit/ast/nodes/test_fold_binop_decimal.py @@ -6,9 +6,17 @@ from tests.utils import parse_and_fold from vyper.exceptions import OverflowException, TypeMismatch, ZeroDivisionException +from vyper.semantics.analysis.local import ExprVisitor +from vyper.semantics.types import DecimalT + +DECIMAL_T = DecimalT() st_decimals = st.decimals( - min_value=-(2**32), max_value=2**32, allow_nan=False, allow_infinity=False, places=10 + min_value=DECIMAL_T.decimal_bounds[0], + max_value=DECIMAL_T.decimal_bounds[1], + allow_nan=False, + allow_infinity=False, + places=DECIMAL_T._decimal_places, ) @@ -30,10 +38,11 @@ def foo(a: decimal, b: decimal) -> decimal: try: vyper_ast = parse_and_fold(f"{left} {op} {right}") - old_node = vyper_ast.body[0].value - new_node = old_node.get_folded_value() + expr = vyper_ast.body[0].value + ExprVisitor().visit(expr, DecimalT()) + new_node = expr.get_folded_value() is_valid = True - except ZeroDivisionException: + except (OverflowException, ZeroDivisionException): is_valid = False if is_valid: @@ -71,9 +80,12 @@ def foo({input_value}) -> decimal: literal_op = literal_op.rsplit(maxsplit=1)[0] try: vyper_ast = parse_and_fold(literal_op) - new_node = vyper_ast.body[0].value.get_folded_value() + expr = vyper_ast.body[0].value + ExprVisitor().visit(expr, DecimalT()) + new_node = expr.get_folded_value() expected = new_node.value - is_valid = -(2**127) <= expected < 2**127 + lo, hi = DecimalT().decimal_bounds + is_valid = lo <= expected < hi except (OverflowException, ZeroDivisionException): # for overflow or division/modulus by 0, expect the contract call to revert is_valid = False diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 4ba2d1a593..c78ecb6d89 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -26,7 +26,14 @@ VyperException, ZeroDivisionException, ) -from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code, evm_div, sha256sum +from vyper.utils import ( + MAX_DECIMAL_PLACES, + SizeLimits, + annotate_source_code, + evm_div, + quantize, + sha256sum, +) NODE_BASE_ATTRIBUTES = ( "_children", @@ -824,6 +831,7 @@ def to_dict(self): return ast_dict def validate(self): + # note: maybe use self.value == quantize(self.value) for this check if self.value.as_tuple().exponent < -MAX_DECIMAL_PLACES: raise InvalidLiteral("Vyper supports a maximum of ten decimal points", self) if self.value < SizeLimits.MIN_AST_DECIMAL: @@ -1010,9 +1018,15 @@ def _op(self, left, right): value = left * right if isinstance(left, decimal.Decimal): # ensure that the result is truncated to MAX_DECIMAL_PLACES - return value.quantize( - decimal.Decimal(f"{1:0.{MAX_DECIMAL_PLACES}f}"), decimal.ROUND_DOWN - ) + try: + # if the intermediate result requires too many decimal places, + # decimal will puke - catch the error and raise an + # OverflowException + return quantize(value) + except decimal.InvalidOperation: + msg = f"{self._description} requires too many decimal places:" + msg += f"\n {left} * {right} => {value}" + raise OverflowException(msg, self) from None else: return value @@ -1036,7 +1050,12 @@ def _op(self, left, right): # the EVM always truncates toward zero value = -(-left / right) # ensure that the result is truncated to MAX_DECIMAL_PLACES - return value.quantize(decimal.Decimal(f"{1:0.{MAX_DECIMAL_PLACES}f}"), decimal.ROUND_DOWN) + try: + return quantize(value) + except decimal.InvalidOperation: + msg = f"{self._description} requires too many decimal places:" + msg += f"\n {left} {self._pretty} {right} => {value}" + raise OverflowException(msg, self) from None class FloorDiv(VyperNode): diff --git a/vyper/utils.py b/vyper/utils.py index 114ddf97c2..cf8a709997 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -407,6 +407,11 @@ class SizeLimits: MAX_UINT256 = 2**256 - 1 +def quantize(d: decimal.Decimal, places=MAX_DECIMAL_PLACES, rounding_mode=decimal.ROUND_DOWN): + quantizer = decimal.Decimal(f"{1:0.{places}f}") + return d.quantize(quantizer, rounding_mode) + + # List of valid IR macros. # TODO move this somewhere else, like ir_node.py VALID_IR_MACROS = {