Skip to content

Commit

Permalink
Improve support for type .max/.min and Enums
Browse files Browse the repository at this point in the history
  • Loading branch information
smonicas committed Oct 7, 2024
1 parent afac8c4 commit 48d7d9b
Showing 1 changed file with 116 additions and 6 deletions.
122 changes: 116 additions & 6 deletions slither/visitors/expression/constants_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
TupleExpression,
TypeConversion,
CallExpression,
MemberAccess,
)
from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression
from slither.core.variables import Variable
from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int
from slither.visitors.expression.expression import ExpressionVisitor
Expand All @@ -27,7 +29,13 @@ class NotConstant(Exception):
KEY = "ConstantFolding"

CONSTANT_TYPES_OPERATIONS = Union[
Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion
Literal,
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
]


Expand Down Expand Up @@ -69,6 +77,7 @@ def result(self) -> "Literal":
# pylint: disable=import-outside-toplevel
def _post_identifier(self, expression: Identifier) -> None:
from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.declarations.enum import Enum

if isinstance(expression.value, Variable):
if expression.value.is_constant:
Expand All @@ -77,7 +86,14 @@ def _post_identifier(self, expression: Identifier) -> None:
# Everything outside of literal
if isinstance(
expr,
(BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
(
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
),
):
cf = ConstantFolding(expr, self._type)
expr = cf.result()
Expand All @@ -88,20 +104,39 @@ def _post_identifier(self, expression: Identifier) -> None:
elif isinstance(expression.value, SolidityFunction):
set_val(expression, expression.value)
else:
raise NotConstant
# We don't want to raise an error for a direct access to an Enum as they can be converted to a constant value
# We can't handle it here because we don't have the field accessed so we do it in _post_member_access
if not isinstance(expression.value, Enum):
raise NotConstant

# pylint: disable=too-many-branches,too-many-statements
def _post_binary_operation(self, expression: BinaryOperation) -> None:
expression_left = expression.expression_left
expression_right = expression.expression_right
if not isinstance(
expression_left,
(Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
(
Literal,
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
),
):
raise NotConstant
if not isinstance(
expression_right,
(Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
(
Literal,
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
),
):
raise NotConstant
left = get_val(expression_left)
Expand Down Expand Up @@ -205,6 +240,22 @@ def _post_assignement_operation(self, expression: expressions.AssignmentOperatio
raise NotConstant

def _post_call_expression(self, expression: expressions.CallExpression) -> None:
from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.declarations.enum import Enum

# pylint: disable=too-many-boolean-expressions
if (
isinstance(expression.called, Identifier)
and expression.called.value == SolidityFunction("type()")
and len(expression.arguments) == 1
and (
isinstance(expression.arguments[0], ElementaryTypeNameExpression)
or isinstance(expression.arguments[0], Identifier)
and isinstance(expression.arguments[0].value, Enum)
)
):
# Returning early to support type(ElemType).max/min or type(MyEnum).max/min
return
called = get_val(expression.called)
args = [get_val(arg) for arg in expression.arguments]
if called.name == "keccak256(bytes)":
Expand All @@ -220,12 +271,70 @@ def _post_conditional_expression(self, expression: expressions.ConditionalExpres
def _post_elementary_type_name_expression(
self, expression: expressions.ElementaryTypeNameExpression
) -> None:
raise NotConstant
# We don't have to raise an exception to support type(uint112).max or similar
pass

def _post_index_access(self, expression: expressions.IndexAccess) -> None:
raise NotConstant

def _post_member_access(self, expression: expressions.MemberAccess) -> None:
from slither.core.declarations import (
SolidityFunction,
Contract,
EnumContract,
EnumTopLevel,
Enum,
)
from slither.core.solidity_types import UserDefinedType

# pylint: disable=too-many-nested-blocks
if isinstance(expression.expression, CallExpression) and expression.member_name in [
"min",
"max",
]:
if isinstance(expression.expression.called, Identifier):
if expression.expression.called.value == SolidityFunction("type()"):
assert len(expression.expression.arguments) == 1
type_expression_found = expression.expression.arguments[0]
type_found: Union[ElementaryType, UserDefinedType]
if isinstance(type_expression_found, ElementaryTypeNameExpression):
type_expression_found_type = type_expression_found.type
assert isinstance(type_expression_found_type, ElementaryType)
type_found = type_expression_found_type
value = (
type_found.max if expression.member_name == "max" else type_found.min
)
set_val(expression, value)
return
# type(enum).max/min
# Case when enum is in another contract e.g. type(C.E).max
if isinstance(type_expression_found, MemberAccess):
contract = type_expression_found.expression.value
assert isinstance(contract, Contract)
for enum in contract.enums:
if enum.name == type_expression_found.member_name:
type_found_in_expression = enum
type_found = UserDefinedType(enum)
break
else:
assert isinstance(type_expression_found, Identifier)
type_found_in_expression = type_expression_found.value
assert isinstance(type_found_in_expression, (EnumContract, EnumTopLevel))
type_found = UserDefinedType(type_found_in_expression)
value = (
type_found_in_expression.max
if expression.member_name == "max"
else type_found_in_expression.min
)
set_val(expression, value)
return
elif isinstance(expression.expression, Identifier) and isinstance(
expression.expression.value, Enum
):
# Handle direct access to enum field
set_val(expression, expression.expression.value.values.index(expression.member_name))
return

raise NotConstant

def _post_new_array(self, expression: expressions.NewArray) -> None:
Expand Down Expand Up @@ -272,6 +381,7 @@ def _post_type_conversion(self, expression: expressions.TypeConversion) -> None:
TupleExpression,
TypeConversion,
CallExpression,
MemberAccess,
),
):
raise NotConstant
Expand Down

0 comments on commit 48d7d9b

Please sign in to comment.