From afac8c4bf036583a66f40ee2845df3334be4fd77 Mon Sep 17 00:00:00 2001 From: Simone Date: Fri, 4 Oct 2024 18:12:00 +0200 Subject: [PATCH 1/5] Improve values extraction --- slither/printers/guidance/echidna.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/slither/printers/guidance/echidna.py b/slither/printers/guidance/echidna.py index 35a609193..b87fc76e6 100644 --- a/slither/printers/guidance/echidna.py +++ b/slither/printers/guidance/echidna.py @@ -226,6 +226,9 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n # Do not report struct_name in a.struct_name if isinstance(ir, Member): continue + if isinstance(var_read, Variable) and var_read.is_constant: + value = ConstantFolding(var_read.expression, var_read.type).result() + all_cst_used.append(ConstantValue(str(value), str(var_read.type))) if isinstance(var_read, Constant): all_cst_used.append(ConstantValue(str(var_read.value), str(var_read.type))) if isinstance(var_read, StateVariable): From 48d7d9beef3d2dcc54aff6a2ea7ab7656d05b2a0 Mon Sep 17 00:00:00 2001 From: Simone Date: Mon, 7 Oct 2024 17:44:13 +0200 Subject: [PATCH 2/5] Improve support for type .max/.min and Enums --- .../visitors/expression/constants_folding.py | 122 +++++++++++++++++- 1 file changed, 116 insertions(+), 6 deletions(-) diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index b1fa570c6..c8cfeb716 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -13,7 +13,9 @@ TupleExpression, TypeConversion, CallExpression, + MemberAccess, ) +from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression from slither.core.variables import Variable from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int from slither.visitors.expression.expression import ExpressionVisitor @@ -27,7 +29,13 @@ class NotConstant(Exception): KEY = "ConstantFolding" CONSTANT_TYPES_OPERATIONS = Union[ - Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion + Literal, + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, ] @@ -69,6 +77,7 @@ def result(self) -> "Literal": # pylint: disable=import-outside-toplevel def _post_identifier(self, expression: Identifier) -> None: from slither.core.declarations.solidity_variables import SolidityFunction + from slither.core.declarations.enum import Enum if isinstance(expression.value, Variable): if expression.value.is_constant: @@ -77,7 +86,14 @@ def _post_identifier(self, expression: Identifier) -> None: # Everything outside of literal if isinstance( expr, - (BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ( + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, + ), ): cf = ConstantFolding(expr, self._type) expr = cf.result() @@ -88,7 +104,10 @@ def _post_identifier(self, expression: Identifier) -> None: elif isinstance(expression.value, SolidityFunction): set_val(expression, expression.value) else: - raise NotConstant + # We don't want to raise an error for a direct access to an Enum as they can be converted to a constant value + # We can't handle it here because we don't have the field accessed so we do it in _post_member_access + if not isinstance(expression.value, Enum): + raise NotConstant # pylint: disable=too-many-branches,too-many-statements def _post_binary_operation(self, expression: BinaryOperation) -> None: @@ -96,12 +115,28 @@ def _post_binary_operation(self, expression: BinaryOperation) -> None: expression_right = expression.expression_right if not isinstance( expression_left, - (Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ( + Literal, + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, + ), ): raise NotConstant if not isinstance( expression_right, - (Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ( + Literal, + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, + ), ): raise NotConstant left = get_val(expression_left) @@ -205,6 +240,22 @@ def _post_assignement_operation(self, expression: expressions.AssignmentOperatio raise NotConstant def _post_call_expression(self, expression: expressions.CallExpression) -> None: + from slither.core.declarations.solidity_variables import SolidityFunction + from slither.core.declarations.enum import Enum + + # pylint: disable=too-many-boolean-expressions + if ( + isinstance(expression.called, Identifier) + and expression.called.value == SolidityFunction("type()") + and len(expression.arguments) == 1 + and ( + isinstance(expression.arguments[0], ElementaryTypeNameExpression) + or isinstance(expression.arguments[0], Identifier) + and isinstance(expression.arguments[0].value, Enum) + ) + ): + # Returning early to support type(ElemType).max/min or type(MyEnum).max/min + return called = get_val(expression.called) args = [get_val(arg) for arg in expression.arguments] if called.name == "keccak256(bytes)": @@ -220,12 +271,70 @@ def _post_conditional_expression(self, expression: expressions.ConditionalExpres def _post_elementary_type_name_expression( self, expression: expressions.ElementaryTypeNameExpression ) -> None: - raise NotConstant + # We don't have to raise an exception to support type(uint112).max or similar + pass def _post_index_access(self, expression: expressions.IndexAccess) -> None: raise NotConstant def _post_member_access(self, expression: expressions.MemberAccess) -> None: + from slither.core.declarations import ( + SolidityFunction, + Contract, + EnumContract, + EnumTopLevel, + Enum, + ) + from slither.core.solidity_types import UserDefinedType + + # pylint: disable=too-many-nested-blocks + if isinstance(expression.expression, CallExpression) and expression.member_name in [ + "min", + "max", + ]: + if isinstance(expression.expression.called, Identifier): + if expression.expression.called.value == SolidityFunction("type()"): + assert len(expression.expression.arguments) == 1 + type_expression_found = expression.expression.arguments[0] + type_found: Union[ElementaryType, UserDefinedType] + if isinstance(type_expression_found, ElementaryTypeNameExpression): + type_expression_found_type = type_expression_found.type + assert isinstance(type_expression_found_type, ElementaryType) + type_found = type_expression_found_type + value = ( + type_found.max if expression.member_name == "max" else type_found.min + ) + set_val(expression, value) + return + # type(enum).max/min + # Case when enum is in another contract e.g. type(C.E).max + if isinstance(type_expression_found, MemberAccess): + contract = type_expression_found.expression.value + assert isinstance(contract, Contract) + for enum in contract.enums: + if enum.name == type_expression_found.member_name: + type_found_in_expression = enum + type_found = UserDefinedType(enum) + break + else: + assert isinstance(type_expression_found, Identifier) + type_found_in_expression = type_expression_found.value + assert isinstance(type_found_in_expression, (EnumContract, EnumTopLevel)) + type_found = UserDefinedType(type_found_in_expression) + value = ( + type_found_in_expression.max + if expression.member_name == "max" + else type_found_in_expression.min + ) + set_val(expression, value) + return + elif isinstance(expression.expression, Identifier) and isinstance( + expression.expression.value, Enum + ): + # Handle direct access to enum field + set_val(expression, expression.expression.value.values.index(expression.member_name)) + return + raise NotConstant def _post_new_array(self, expression: expressions.NewArray) -> None: @@ -272,6 +381,7 @@ def _post_type_conversion(self, expression: expressions.TypeConversion) -> None: TupleExpression, TypeConversion, CallExpression, + MemberAccess, ), ): raise NotConstant From f93b65fef3bc7e6cebccd444215d9579e0bc22bd Mon Sep 17 00:00:00 2001 From: Simone Date: Mon, 7 Oct 2024 23:09:01 +0200 Subject: [PATCH 3/5] Fix for user defined type and variables defined in another contract --- slither/printers/guidance/echidna.py | 12 ++++- .../visitors/expression/constants_folding.py | 48 +++++++++++++++++-- 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/slither/printers/guidance/echidna.py b/slither/printers/guidance/echidna.py index b87fc76e6..057e78269 100644 --- a/slither/printers/guidance/echidna.py +++ b/slither/printers/guidance/echidna.py @@ -227,8 +227,16 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n if isinstance(ir, Member): continue if isinstance(var_read, Variable) and var_read.is_constant: - value = ConstantFolding(var_read.expression, var_read.type).result() - all_cst_used.append(ConstantValue(str(value), str(var_read.type))) + # In case of type conversion we use the destination type + if isinstance(ir, TypeConversion): + if isinstance(ir.type, TypeAlias): + value_type = ir.type.type + else: + value_type = ir.type + else: + value_type = var_read.type + value = ConstantFolding(var_read.expression, value_type).result() + all_cst_used.append(ConstantValue(str(value), str(value_type))) if isinstance(var_read, Constant): all_cst_used.append(ConstantValue(str(var_read.value), str(var_read.type))) if isinstance(var_read, StateVariable): diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index c8cfeb716..63828bdd8 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -78,6 +78,8 @@ def result(self) -> "Literal": def _post_identifier(self, expression: Identifier) -> None: from slither.core.declarations.solidity_variables import SolidityFunction from slither.core.declarations.enum import Enum + from slither.core.solidity_types.type_alias import TypeAlias + from slither.core.declarations.contract import Contract if isinstance(expression.value, Variable): if expression.value.is_constant: @@ -104,9 +106,11 @@ def _post_identifier(self, expression: Identifier) -> None: elif isinstance(expression.value, SolidityFunction): set_val(expression, expression.value) else: - # We don't want to raise an error for a direct access to an Enum as they can be converted to a constant value + # Enum: We don't want to raise an error for a direct access to an Enum as they can be converted to a constant value # We can't handle it here because we don't have the field accessed so we do it in _post_member_access - if not isinstance(expression.value, Enum): + # TypeAlias: Support when a .wrap() is done with a constant + # Contract: Support when a constatn is use from a different contract + if not isinstance(expression.value, (Enum, TypeAlias, Contract)): raise NotConstant # pylint: disable=too-many-branches,too-many-statements @@ -242,6 +246,7 @@ def _post_assignement_operation(self, expression: expressions.AssignmentOperatio def _post_call_expression(self, expression: expressions.CallExpression) -> None: from slither.core.declarations.solidity_variables import SolidityFunction from slither.core.declarations.enum import Enum + from slither.core.solidity_types import TypeAlias # pylint: disable=too-many-boolean-expressions if ( @@ -256,6 +261,17 @@ def _post_call_expression(self, expression: expressions.CallExpression) -> None: ): # Returning early to support type(ElemType).max/min or type(MyEnum).max/min return + if ( + isinstance(expression.called.expression, Identifier) + and isinstance(expression.called.expression.value, TypeAlias) + and isinstance(expression.called, MemberAccess) + and expression.called.member_name == "wrap" + and len(expression.arguments) == 1 + ): + # Handle constants in .wrap of user defined type + set_val(expression, get_val(expression.arguments[0])) + return + called = get_val(expression.called) args = [get_val(arg) for arg in expression.arguments] if called.name == "keccak256(bytes)": @@ -277,6 +293,7 @@ def _post_elementary_type_name_expression( def _post_index_access(self, expression: expressions.IndexAccess) -> None: raise NotConstant + # pylint: disable=too-many-locals def _post_member_access(self, expression: expressions.MemberAccess) -> None: from slither.core.declarations import ( SolidityFunction, @@ -285,7 +302,7 @@ def _post_member_access(self, expression: expressions.MemberAccess) -> None: EnumTopLevel, Enum, ) - from slither.core.solidity_types import UserDefinedType + from slither.core.solidity_types import UserDefinedType, TypeAlias # pylint: disable=too-many-nested-blocks if isinstance(expression.expression, CallExpression) and expression.member_name in [ @@ -334,6 +351,31 @@ def _post_member_access(self, expression: expressions.MemberAccess) -> None: # Handle direct access to enum field set_val(expression, expression.expression.value.values.index(expression.member_name)) return + elif isinstance(expression.expression, Identifier) and isinstance( + expression.expression.value, TypeAlias + ): + # User defined type .wrap call handled in _post_call_expression + return + elif ( + isinstance(expression.expression.value, Contract) + and expression.member_name in expression.expression.value.variables_as_dict + and expression.expression.value.variables_as_dict[expression.member_name].is_constant + ): + # Handles when a constant is accessed on another contract + variables = expression.expression.value.variables_as_dict + if isinstance(variables[expression.member_name].expression, MemberAccess): + self._post_member_access(variables[expression.member_name].expression) + set_val(expression, get_val(variables[expression.member_name].expression)) + return + + # If the variable is a Literal we convert its value to int + value = ( + convert_string_to_int(variables[expression.member_name].expression.converted_value) + if isinstance(variables[expression.member_name].expression, Literal) + else variables[expression.member_name].expression + ) + set_val(expression, value) + return raise NotConstant From 607b22a506fb875c6710fd337d61cc2d1f39e0c2 Mon Sep 17 00:00:00 2001 From: Simone Date: Mon, 7 Oct 2024 23:40:56 +0200 Subject: [PATCH 4/5] Handle UnaryOperation from another contract --- .../visitors/expression/constants_folding.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index 63828bdd8..ddadb77a1 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -369,11 +369,19 @@ def _post_member_access(self, expression: expressions.MemberAccess) -> None: return # If the variable is a Literal we convert its value to int - value = ( - convert_string_to_int(variables[expression.member_name].expression.converted_value) - if isinstance(variables[expression.member_name].expression, Literal) - else variables[expression.member_name].expression - ) + if isinstance(variables[expression.member_name].expression, Literal): + value = convert_string_to_int( + variables[expression.member_name].expression.converted_value + ) + # If the variable is a UnaryOperation we need convert its value to int + # and replacing possible spaces + elif isinstance(variables[expression.member_name].expression, UnaryOperation): + value = convert_string_to_int( + str(variables[expression.member_name].expression).replace(" ", "") + ) + else: + value = variables[expression.member_name].expression + set_val(expression, value) return From f6b250944569a7e2eb6380f7593f89bc84bddc56 Mon Sep 17 00:00:00 2001 From: Feist Josselin Date: Thu, 17 Oct 2024 22:41:28 +0200 Subject: [PATCH 5/5] Add try/catch, add tests, refactor guidance/echidna to reduce pylint exceptions --- slither/printers/guidance/echidna.py | 110 +++++++++++------- tests/unit/slithir/test_constantfolding.py | 22 ++++ .../slithir/test_data/constantfolding.sol | 19 +++ 3 files changed, 106 insertions(+), 45 deletions(-) create mode 100644 tests/unit/slithir/test_constantfolding.py create mode 100644 tests/unit/slithir/test_data/constantfolding.sol diff --git a/slither/printers/guidance/echidna.py b/slither/printers/guidance/echidna.py index 057e78269..7e76cec0d 100644 --- a/slither/printers/guidance/echidna.py +++ b/slither/printers/guidance/echidna.py @@ -13,6 +13,7 @@ from slither.core.expressions import NewContract from slither.core.slither_core import SlitherCore from slither.core.solidity_types import TypeAlias +from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable from slither.printers.abstract_printer import AbstractPrinter @@ -179,7 +180,66 @@ class ConstantValue(NamedTuple): # pylint: disable=inherit-non-class,too-few-pu type: str -def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-nested-blocks +def _extract_constant_from_read( + ir: Operation, + r: SourceMapping, + all_cst_used: List[ConstantValue], + all_cst_used_in_binary: Dict[str, List[ConstantValue]], + context_explored: Set[Node], +) -> None: + var_read = r.points_to_origin if isinstance(r, ReferenceVariable) else r + # Do not report struct_name in a.struct_name + if isinstance(ir, Member): + return + if isinstance(var_read, Variable) and var_read.is_constant: + # In case of type conversion we use the destination type + if isinstance(ir, TypeConversion): + if isinstance(ir.type, TypeAlias): + value_type = ir.type.type + else: + value_type = ir.type + else: + value_type = var_read.type + try: + value = ConstantFolding(var_read.expression, value_type).result() + all_cst_used.append(ConstantValue(str(value), str(value_type))) + except NotConstant: + pass + if isinstance(var_read, Constant): + all_cst_used.append(ConstantValue(str(var_read.value), str(var_read.type))) + if isinstance(var_read, StateVariable): + if var_read.node_initialization: + if var_read.node_initialization.irs: + if var_read.node_initialization in context_explored: + return + context_explored.add(var_read.node_initialization) + _extract_constants_from_irs( + var_read.node_initialization.irs, + all_cst_used, + all_cst_used_in_binary, + context_explored, + ) + + +def _extract_constant_from_binary( + ir: Binary, + all_cst_used: List[ConstantValue], + all_cst_used_in_binary: Dict[str, List[ConstantValue]], +): + for r in ir.read: + if isinstance(r, Constant): + all_cst_used_in_binary[str(ir.type)].append(ConstantValue(str(r.value), str(r.type))) + if isinstance(ir.variable_left, Constant) or isinstance(ir.variable_right, Constant): + if ir.lvalue: + try: + type_ = ir.lvalue.type + cst = ConstantFolding(ir.expression, type_).result() + all_cst_used.append(ConstantValue(str(cst.value), str(type_))) + except NotConstant: + pass + + +def _extract_constants_from_irs( irs: List[Operation], all_cst_used: List[ConstantValue], all_cst_used_in_binary: Dict[str, List[ConstantValue]], @@ -187,21 +247,7 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n ) -> None: for ir in irs: if isinstance(ir, Binary): - for r in ir.read: - if isinstance(r, Constant): - all_cst_used_in_binary[str(ir.type)].append( - ConstantValue(str(r.value), str(r.type)) - ) - if isinstance(ir.variable_left, Constant) or isinstance( - ir.variable_right, Constant - ): - if ir.lvalue: - try: - type_ = ir.lvalue.type - cst = ConstantFolding(ir.expression, type_).result() - all_cst_used.append(ConstantValue(str(cst.value), str(type_))) - except NotConstant: - pass + _extract_constant_from_binary(ir, all_cst_used, all_cst_used_in_binary) if isinstance(ir, TypeConversion): if isinstance(ir.variable, Constant): if isinstance(ir.type, TypeAlias): @@ -222,35 +268,9 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n except ValueError: # index could fail; should never happen in working solidity code pass for r in ir.read: - var_read = r.points_to_origin if isinstance(r, ReferenceVariable) else r - # Do not report struct_name in a.struct_name - if isinstance(ir, Member): - continue - if isinstance(var_read, Variable) and var_read.is_constant: - # In case of type conversion we use the destination type - if isinstance(ir, TypeConversion): - if isinstance(ir.type, TypeAlias): - value_type = ir.type.type - else: - value_type = ir.type - else: - value_type = var_read.type - value = ConstantFolding(var_read.expression, value_type).result() - all_cst_used.append(ConstantValue(str(value), str(value_type))) - if isinstance(var_read, Constant): - all_cst_used.append(ConstantValue(str(var_read.value), str(var_read.type))) - if isinstance(var_read, StateVariable): - if var_read.node_initialization: - if var_read.node_initialization.irs: - if var_read.node_initialization in context_explored: - continue - context_explored.add(var_read.node_initialization) - _extract_constants_from_irs( - var_read.node_initialization.irs, - all_cst_used, - all_cst_used_in_binary, - context_explored, - ) + _extract_constant_from_read( + ir, r, all_cst_used, all_cst_used_in_binary, context_explored + ) def _extract_constants( diff --git a/tests/unit/slithir/test_constantfolding.py b/tests/unit/slithir/test_constantfolding.py new file mode 100644 index 000000000..fcf00035b --- /dev/null +++ b/tests/unit/slithir/test_constantfolding.py @@ -0,0 +1,22 @@ +from pathlib import Path + +from slither import Slither +from slither.printers.guidance.echidna import _extract_constants, ConstantValue + +TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" + + +def test_enum_max_min(solc_binary_path) -> None: + solc_path = solc_binary_path("0.8.19") + slither = Slither(Path(TEST_DATA_DIR, "constantfolding.sol").as_posix(), solc=solc_path) + + contracts = slither.get_contract_from_name("A") + + constants = _extract_constants(contracts)[0]["A"]["use()"] + + assert set(constants) == { + ConstantValue(value="2", type="uint256"), + ConstantValue(value="10", type="uint256"), + ConstantValue(value="100", type="uint256"), + ConstantValue(value="4294967295", type="uint32"), + } diff --git a/tests/unit/slithir/test_data/constantfolding.sol b/tests/unit/slithir/test_data/constantfolding.sol new file mode 100644 index 000000000..aef4a2427 --- /dev/null +++ b/tests/unit/slithir/test_data/constantfolding.sol @@ -0,0 +1,19 @@ +type MyType is uint256; + +contract A{ + + enum E{ + a,b,c + } + + + uint a = 10; + E b = type(E).max; + uint c = type(uint32).max; + MyType d = MyType.wrap(100); + + function use() public returns(uint){ + E e = b; + return a +c + MyType.unwrap(d); + } +} \ No newline at end of file