diff --git a/wake/ir/types.py b/wake/ir/types.py index 416d81321..c1e1c6f19 100644 --- a/wake/ir/types.py +++ b/wake/ir/types.py @@ -34,56 +34,102 @@ def from_type_identifier( type_identifier: StringReader, reference_resolver: ReferenceResolver, cu_hash: bytes, - ) -> typ.Optional["TypeAbc"]: + ) -> typ.Optional[TypeAbc]: if type_identifier.startswith("t_address"): - return Address(type_identifier) + return Address.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_bool"): - return Bool(type_identifier) + return Bool.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_int"): - return Int(type_identifier) + return Int.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_uint"): - return UInt(type_identifier) + return UInt.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_stringliteral_"): - return StringLiteral(type_identifier) + return StringLiteral.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_string_"): - return String(type_identifier) + return String.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_bytes_"): - return Bytes(type_identifier) + return Bytes.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) # must go after t_bytes_ !! elif type_identifier.startswith("t_bytes"): - return FixedBytes(type_identifier) + return FixedBytes.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_function"): - return Function(type_identifier, reference_resolver, cu_hash) + return Function.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_tuple"): - return Tuple(type_identifier, reference_resolver, cu_hash) + return Tuple.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_type"): - return Type(type_identifier, reference_resolver, cu_hash) + return Type.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_rational"): - return Rational(type_identifier) + return Rational.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_modifier"): - return Modifier(type_identifier, reference_resolver, cu_hash) + return Modifier.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_array"): - return Array(type_identifier, reference_resolver, cu_hash) + return Array.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_mapping"): - return Mapping(type_identifier, reference_resolver, cu_hash) + return Mapping.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_contract") or type_identifier.startswith( "t_super" ): - return Contract(type_identifier, reference_resolver, cu_hash) + return Contract.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_struct"): - return Struct(type_identifier, reference_resolver, cu_hash) + return Struct.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_enum"): - return Enum(type_identifier, reference_resolver, cu_hash) + return Enum.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_magic"): - return Magic(type_identifier, reference_resolver, cu_hash) + return Magic.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_userDefinedValueType"): - return UserDefinedValueType(type_identifier, reference_resolver, cu_hash) + return UserDefinedValueType.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_module"): - return Module(type_identifier, reference_resolver, cu_hash) + return Module.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_fixed"): - return Fixed(type_identifier) + return Fixed.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) elif type_identifier.startswith("t_ufixed"): - return UFixed(type_identifier) + return UFixed.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) else: return None @@ -154,20 +200,29 @@ class Address(TypeAbc): _is_payable: bool - def __init__(self, type_identifier: StringReader): - type_identifier.read("t_address") - - if type_identifier.startswith("_payable"): - type_identifier.read("_payable") - self._is_payable = True - else: - self._is_payable = False + def __init__(self, is_payable: bool): + self._is_payable = is_payable def __eq__(self, other: object) -> bool: if not isinstance(other, Address): return False return self._is_payable == other._is_payable + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> Address: + type_identifier.read("t_address") + + if type_identifier.startswith("_payable"): + type_identifier.read("_payable") + return Address(True) + else: + return Address(False) + @property def abi_type(self) -> str: return "address" @@ -186,12 +241,19 @@ class Bool(TypeAbc): Boolean type. """ - def __init__(self, type_identifier: StringReader): - type_identifier.read("t_bool") - def __eq__(self, other: object) -> bool: return isinstance(other, Bool) + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> Bool: + type_identifier.read("t_bool") + return Bool() + @property def abi_type(self) -> str: return "bool" @@ -220,19 +282,28 @@ class Int(IntAbc): Signed integer type. """ - def __init__(self, type_identifier: StringReader): - type_identifier.read("t_int") - match = NUMBER_RE.match(type_identifier.data) - assert match is not None - number = match.group("number") - type_identifier.read(number) - self._bits_count = int(number) + def __init__(self, bits_count: int): + self._bits_count = bits_count def __eq__(self, other: object) -> bool: if not isinstance(other, Int): return False return self._bits_count == other._bits_count + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> Int: + type_identifier.read("t_int") + match = NUMBER_RE.match(type_identifier.data) + assert match is not None + number = match.group("number") + type_identifier.read(number) + return Int(int(number)) + @property def abi_type(self) -> str: return f"int{self._bits_count}" @@ -243,19 +314,28 @@ class UInt(IntAbc): Unsigned integer type. """ - def __init__(self, type_identifier: StringReader): - type_identifier.read("t_uint") - match = NUMBER_RE.match(type_identifier.data) - assert match is not None - number = match.group("number") - type_identifier.read(number) - self._bits_count = int(number) + def __init__(self, bits_count: int): + self._bits_count = bits_count def __eq__(self, other: object) -> bool: if not isinstance(other, UInt): return False return self._bits_count == other._bits_count + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> UInt: + type_identifier.read("t_uint") + match = NUMBER_RE.match(type_identifier.data) + assert match is not None + number = match.group("number") + type_identifier.read(number) + return UInt(int(number)) + @property def abi_type(self) -> str: return f"uint{self._bits_count}" @@ -295,13 +375,30 @@ class Fixed(FixedAbc): Currently not fully implemented in Solidity. """ - def __init__(self, type_identifier: StringReader): + def __init__(self, total_bits: int, fractional_digits: int): + self._total_bits = total_bits + self._fractional_digits = fractional_digits + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Fixed): + return False + return ( + self._total_bits == other._total_bits + and self._fractional_digits == other._fractional_digits + ) + + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> Fixed: type_identifier.read("t_fixed") match = NUMBER_RE.match(type_identifier.data) assert match is not None total_bits = match.group("number") type_identifier.read(total_bits) - self._total_bits = int(total_bits) type_identifier.read("x") @@ -309,19 +406,11 @@ def __init__(self, type_identifier: StringReader): assert match is not None fractional_digits = match.group("number") type_identifier.read(fractional_digits) - self._fractional_digits = int(fractional_digits) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Fixed): - return False - return ( - self._total_bits == other._total_bits - and self._fractional_digits == other._fractional_digits - ) + return Fixed(int(total_bits), int(fractional_digits)) @property def abi_type(self) -> str: - raise NotImplementedError + return f"fixed{self._total_bits}x{self._fractional_digits}" class UFixed(FixedAbc): @@ -331,13 +420,30 @@ class UFixed(FixedAbc): Currently not fully implemented in Solidity. """ - def __init__(self, type_identifier: StringReader): + def __init__(self, total_bits: int, fractional_digits: int): + self._total_bits = total_bits + self._fractional_digits = fractional_digits + + def __eq__(self, other: object) -> bool: + if not isinstance(other, UFixed): + return False + return ( + self._total_bits == other._total_bits + and self._fractional_digits == other._fractional_digits + ) + + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> UFixed: type_identifier.read("t_ufixed") match = NUMBER_RE.match(type_identifier.data) assert match is not None total_bits = match.group("number") type_identifier.read(total_bits) - self._total_bits = int(total_bits) type_identifier.read("x") @@ -345,19 +451,11 @@ def __init__(self, type_identifier: StringReader): assert match is not None fractional_digits = match.group("number") type_identifier.read(fractional_digits) - self._fractional_digits = int(fractional_digits) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, UFixed): - return False - return ( - self._total_bits == other._total_bits - and self._fractional_digits == other._fractional_digits - ) + return UFixed(int(total_bits), int(fractional_digits)) @property def abi_type(self) -> str: - raise NotImplementedError + return f"ufixed{self._total_bits}x{self._fractional_digits}" class StringLiteral(TypeAbc): @@ -377,19 +475,28 @@ class StringLiteral(TypeAbc): _keccak256_hash: bytes - def __init__(self, type_identifier: StringReader): - type_identifier.read("t_stringliteral_") - match = HEX_RE.match(type_identifier.data) - assert match is not None - hex = match.group("hex") - type_identifier.read(hex) - self._keccak256_hash = bytes.fromhex(hex) + def __init__(self, keccak256_hash: bytes): + self._keccak256_hash = keccak256_hash def __eq__(self, other: object) -> bool: if not isinstance(other, StringLiteral): return False return self._keccak256_hash == other._keccak256_hash + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> StringLiteral: + type_identifier.read("t_stringliteral_") + match = HEX_RE.match(type_identifier.data) + assert match is not None + hex = match.group("hex") + type_identifier.read(hex) + return StringLiteral(bytes.fromhex(hex)) + @property def abi_type(self) -> str: raise NotImplementedError @@ -412,40 +519,53 @@ class String(TypeAbc): _is_pointer: bool _is_slice: bool - def __init__(self, type_identifier: StringReader): + def __init__(self, data_location: DataLocation, is_pointer: bool, is_slice: bool): + self._data_location = data_location + self._is_pointer = is_pointer + self._is_slice = is_slice + + def __eq__(self, other: object) -> bool: + if not isinstance(other, String): + return False + return ( + self._data_location == other._data_location + and self._is_pointer == other._is_pointer + and self._is_slice == other._is_slice + ) + + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> String: type_identifier.read("t_string") if type_identifier.startswith("_storage"): - self._data_location = DataLocation.STORAGE + data_location = DataLocation.STORAGE type_identifier.read("_storage") elif type_identifier.startswith("_memory"): - self._data_location = DataLocation.MEMORY + data_location = DataLocation.MEMORY type_identifier.read("_memory") elif type_identifier.startswith("_calldata"): - self._data_location = DataLocation.CALLDATA + data_location = DataLocation.CALLDATA type_identifier.read("_calldata") else: assert False, f"Unexpected string type data location {type_identifier}" if type_identifier.startswith("_ptr"): - self._is_pointer = True + is_pointer = True type_identifier.read("_ptr") else: - self._is_pointer = False + is_pointer = False if type_identifier.startswith("_slice"): - self._is_slice = True + is_slice = True type_identifier.read("_slice") else: - self._is_slice = False + is_slice = False - def __eq__(self, other: object) -> bool: - if not isinstance(other, String): - return False - return ( - self._data_location == other._data_location - and self._is_pointer == other._is_pointer - and self._is_slice == other._is_slice - ) + return String(data_location, is_pointer, is_slice) @property def abi_type(self) -> str: @@ -502,40 +622,53 @@ class Bytes(TypeAbc): _is_pointer: bool _is_slice: bool - def __init__(self, type_identifier: StringReader): + def __init__(self, data_location: DataLocation, is_pointer: bool, is_slice: bool): + self._data_location = data_location + self._is_pointer = is_pointer + self._is_slice = is_slice + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Bytes): + return False + return ( + self._data_location == other._data_location + and self._is_pointer == other._is_pointer + and self._is_slice == other._is_slice + ) + + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> Bytes: type_identifier.read("t_bytes") if type_identifier.startswith("_storage"): - self._data_location = DataLocation.STORAGE + data_location = DataLocation.STORAGE type_identifier.read("_storage") elif type_identifier.startswith("_memory"): - self._data_location = DataLocation.MEMORY + data_location = DataLocation.MEMORY type_identifier.read("_memory") elif type_identifier.startswith("_calldata"): - self._data_location = DataLocation.CALLDATA + data_location = DataLocation.CALLDATA type_identifier.read("_calldata") else: assert False, f"Unexpected string type data location {type_identifier}" if type_identifier.startswith("_ptr"): - self._is_pointer = True + is_pointer = True type_identifier.read("_ptr") else: - self._is_pointer = False + is_pointer = False if type_identifier.startswith("_slice"): - self._is_slice = True + is_slice = True type_identifier.read("_slice") else: - self._is_slice = False + is_slice = False - def __eq__(self, other: object) -> bool: - if not isinstance(other, Bytes): - return False - return ( - self._data_location == other._data_location - and self._is_pointer == other._is_pointer - and self._is_slice == other._is_slice - ) + return Bytes(data_location, is_pointer, is_slice) @property def abi_type(self) -> str: @@ -590,19 +723,28 @@ class FixedBytes(TypeAbc): _bytes_count: int - def __init__(self, type_identifier: StringReader): - type_identifier.read("t_bytes") - match = NUMBER_RE.match(type_identifier.data) - assert match is not None - number = match.group("number") - type_identifier.read(number) - self._bytes_count = int(number) + def __init__(self, bytes_count: int): + self._bytes_count = bytes_count def __eq__(self, other: object) -> bool: if not isinstance(other, FixedBytes): return False return self._bytes_count == other._bytes_count + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> FixedBytes: + type_identifier.read("t_bytes") + match = NUMBER_RE.match(type_identifier.data) + assert match is not None + number = match.group("number") + type_identifier.read(number) + return FixedBytes(int(number)) + @property def abi_type(self) -> str: return f"bytes{self._bytes_count}" @@ -647,10 +789,48 @@ class Function(TypeAbc): def __init__( self, + kind: FunctionTypeKind, + state_mutability: StateMutability, + parameters: typ.Iterable[TypeAbc], + return_parameters: typ.Iterable[TypeAbc], + gas_set: bool, + value_set: bool, + salt_set: bool, + attached_to: typ.Optional[typ.Iterable[TypeAbc]], + ): + self._kind = kind + self._state_mutability = state_mutability + self._parameters = tuple(parameters) + self._return_parameters = tuple(return_parameters) + self._gas_set = gas_set + self._value_set = value_set + self._salt_set = salt_set + if attached_to is not None: + self._attached_to = tuple(attached_to) + else: + self._attached_to = None + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Function): + return False + return ( + self._kind == other._kind + and self._state_mutability == other._state_mutability + and self._parameters == other._parameters + and self._return_parameters == other._return_parameters + and self._gas_set == other._gas_set + and self._value_set == other._value_set + and self._salt_set == other._salt_set + and self._attached_to == other._attached_to + ) + + @classmethod + def from_type_identifier( + cls, type_identifier: StringReader, reference_resolver: ReferenceResolver, cu_hash: bytes, - ): + ) -> Function: type_identifier.read("t_function_") matched = [] @@ -658,78 +838,72 @@ def __init__( if type_identifier.startswith(kind): matched.append(kind) assert len(matched) >= 1, f"Unexpected function kind {type_identifier}" - self._kind = FunctionTypeKind(max(matched, key=len)) - type_identifier.read(self._kind) + kind = FunctionTypeKind(max(matched, key=len)) + type_identifier.read(kind) if type_identifier.startswith("_payable"): - self._state_mutability = StateMutability.PAYABLE + state_mutability = StateMutability.PAYABLE type_identifier.read("_payable") elif type_identifier.startswith("_pure"): - self._state_mutability = StateMutability.PURE + state_mutability = StateMutability.PURE type_identifier.read("_pure") elif type_identifier.startswith("_nonpayable"): - self._state_mutability = StateMutability.NONPAYABLE + state_mutability = StateMutability.NONPAYABLE type_identifier.read("_nonpayable") elif type_identifier.startswith("_view"): - self._state_mutability = StateMutability.VIEW + state_mutability = StateMutability.VIEW type_identifier.read("_view") else: assert False, f"Unexpected function state mutability {type_identifier}" parameters = _parse_list(type_identifier, reference_resolver, cu_hash) assert not any(param is None for param in parameters) - self._parameters = parameters # type: ignore type_identifier.read("returns") return_parameters = _parse_list(type_identifier, reference_resolver, cu_hash) assert not any(param is None for param in return_parameters) - self._return_parameters = return_parameters # type: ignore if type_identifier.startswith("gas"): - self._gas_set = True + gas_set = True type_identifier.read("gas") else: - self._gas_set = False + gas_set = False if type_identifier.startswith("value"): - self._value_set = True + value_set = True type_identifier.read("value") else: - self._value_set = False + value_set = False if type_identifier.startswith("salt"): - self._salt_set = True + salt_set = True type_identifier.read("salt") else: - self._salt_set = False + salt_set = False if type_identifier.startswith("bound_to"): type_identifier.read("bound_to") bound_to = _parse_list(type_identifier, reference_resolver, cu_hash) assert not any(param is None for param in bound_to) - self._attached_to = bound_to # type: ignore + attached_to = bound_to elif type_identifier.startswith( "attached_to" ): # bound_to was renamed to attached_to in 0.8.18 type_identifier.read("attached_to") attached_to = _parse_list(type_identifier, reference_resolver, cu_hash) assert not any(param is None for param in attached_to) - self._attached_to = attached_to # type: ignore else: - self._attached_to = None - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Function): - return False - return ( - self._kind == other._kind - and self._state_mutability == other._state_mutability - and self._parameters == other._parameters - and self._return_parameters == other._return_parameters - and self._gas_set == other._gas_set - and self._value_set == other._value_set - and self._salt_set == other._salt_set - and self._attached_to == other._attached_to + attached_to = None + + return Function( + kind, + state_mutability, + parameters, # pyright: ignore reportArgumentType + return_parameters, # pyright: ignore reportArgumentType + gas_set, + value_set, + salt_set, + attached_to, # pyright: ignore reportArgumentType ) @property @@ -858,20 +1032,25 @@ class Tuple(TypeAbc): _components: typ.Tuple[typ.Optional[TypeAbc], ...] - def __init__( - self, - type_identifier: StringReader, - reference_resolver: ReferenceResolver, - cu_hash: bytes, - ): - type_identifier.read("t_tuple") - self._components = _parse_list(type_identifier, reference_resolver, cu_hash) + def __init__(self, components: typ.Iterable[typ.Optional[TypeAbc]]): + self._components = tuple(components) def __eq__(self, other: object) -> bool: if not isinstance(other, Tuple): return False return self._components == other._components + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> Tuple: + type_identifier.read("t_tuple") + components = _parse_list(type_identifier, reference_resolver, cu_hash) + return Tuple(components) + @property def abi_type(self) -> str: if any(component is None for component in self._components): @@ -909,21 +1088,25 @@ class Type(TypeAbc): _actual_type: TypeAbc - def __init__( - self, + def __init__(self, actual_type: TypeAbc): + self._actual_type = actual_type + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Type): + return False + return self._actual_type == other._actual_type + + @classmethod + def from_type_identifier( + cls, type_identifier: StringReader, reference_resolver: ReferenceResolver, cu_hash: bytes, - ): + ) -> Type: type_identifier.read("t_type") actual_type = _parse_list(type_identifier, reference_resolver, cu_hash) assert len(actual_type) == 1 and actual_type[0] is not None - self._actual_type = actual_type[0] - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Type): - return False - return self._actual_type == other._actual_type + return Type(actual_type[0]) @property def abi_type(self) -> str: @@ -974,20 +1157,38 @@ class Rational(TypeAbc): _numerator: int _denominator: int - def __init__(self, type_identifier: StringReader): + def __init__(self, numerator: int, denominator: int): + self._numerator = numerator + self._denominator = denominator + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Rational): + return False + return ( + self._numerator == other._numerator + and self._denominator == other._denominator + ) + + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> Rational: type_identifier.read("t_rational_") if type_identifier.startswith("minus_"): type_identifier.read("minus_") - self._numerator = -1 + numerator = -1 else: - self._numerator = 1 + numerator = 1 match = NUMBER_RE.match(type_identifier.data) assert match is not None, f"{type_identifier} is not a valid rational" number = match.group("number") type_identifier.read(number) - self._numerator *= int(number) + numerator *= int(number) type_identifier.read("_by_") @@ -995,15 +1196,8 @@ def __init__(self, type_identifier: StringReader): assert match is not None, f"{type_identifier} is not a valid rational" number = match.group("number") type_identifier.read(number) - self._denominator = int(number) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Rational): - return False - return ( - self._numerator == other._numerator - and self._denominator == other._denominator - ) + denominator = int(number) + return Rational(numerator, denominator) @property def abi_type(self) -> str: @@ -1035,21 +1229,25 @@ class Modifier(TypeAbc): _parameters: typ.Tuple[TypeAbc, ...] - def __init__( - self, + def __init__(self, parameters: typ.Iterable[TypeAbc]): + self._parameters = tuple(parameters) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Modifier): + return False + return self._parameters == other._parameters + + @classmethod + def from_type_identifier( + cls, type_identifier: StringReader, reference_resolver: ReferenceResolver, cu_hash: bytes, - ): + ) -> Modifier: type_identifier.read("t_modifier") parameters = _parse_list(type_identifier, reference_resolver, cu_hash) assert not any(param is None for param in parameters) - self._parameters = parameters # type: ignore - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Modifier): - return False - return self._parameters == other._parameters + return Modifier(parameters) # pyright: ignore reportArgumentType @property def abi_type(self) -> str: @@ -1077,60 +1275,77 @@ class Array(TypeAbc): def __init__( self, + base_type: TypeAbc, + length: typ.Optional[int], + data_location: DataLocation, + is_pointer: bool, + is_slice: bool, + ): + self._base_type = base_type + self._length = length + self._data_location = data_location + self._is_pointer = is_pointer + self._is_slice = is_slice + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Array): + return False + return ( + self._base_type == other._base_type + and self._length == other._length + and self._data_location == other._data_location + and self._is_pointer == other._is_pointer + and self._is_slice == other._is_slice + ) + + @classmethod + def from_type_identifier( + cls, type_identifier: StringReader, reference_resolver: ReferenceResolver, cu_hash: bytes, - ): + ) -> Array: type_identifier.read("t_array") base_type = _parse_list(type_identifier, reference_resolver, cu_hash) assert ( len(base_type) == 1 and base_type[0] is not None ), f"Unexpected array base type {type_identifier}" - self._base_type = base_type[0] + base_type = base_type[0] if type_identifier.startswith("dyn"): - self._length = None + length = None type_identifier.read("dyn") else: match = NUMBER_RE.match(type_identifier.data) assert match is not None, f"{type_identifier} is not a valid array length" - self._length = int(match.group("number")) + length = int(match.group("number")) type_identifier.read(match.group("number")) if type_identifier.startswith("_storage"): - self._data_location = DataLocation.STORAGE + data_location = DataLocation.STORAGE type_identifier.read("_storage") elif type_identifier.startswith("_memory"): - self._data_location = DataLocation.MEMORY + data_location = DataLocation.MEMORY type_identifier.read("_memory") elif type_identifier.startswith("_calldata"): - self._data_location = DataLocation.CALLDATA + data_location = DataLocation.CALLDATA type_identifier.read("_calldata") else: assert False, f"Unexpected array type data location {type_identifier}" if type_identifier.startswith("_ptr"): - self._is_pointer = True + is_pointer = True type_identifier.read("_ptr") else: - self._is_pointer = False + is_pointer = False if type_identifier.startswith("_slice"): - self._is_slice = True + is_slice = True type_identifier.read("_slice") else: - self._is_slice = False + is_slice = False - def __eq__(self, other: object) -> bool: - if not isinstance(other, Array): - return False - return ( - self._base_type == other._base_type - and self._length == other._length - and self._data_location == other._data_location - and self._is_pointer == other._is_pointer - and self._is_slice == other._is_slice - ) + return Array(base_type, length, data_location, is_pointer, is_slice) @property def abi_type(self) -> str: @@ -1203,26 +1418,30 @@ class Mapping(TypeAbc): _key_type: TypeAbc _value_type: TypeAbc - def __init__( - self, + def __init__(self, key_type: TypeAbc, value_type: TypeAbc): + self._key_type = key_type + self._value_type = value_type + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Mapping): + return False + return ( + self._key_type == other._key_type and self._value_type == other._value_type + ) + + @classmethod + def from_type_identifier( + cls, type_identifier: StringReader, reference_resolver: ReferenceResolver, cu_hash: bytes, - ): + ) -> Mapping: type_identifier.read("t_mapping") key_value = _parse_list(type_identifier, reference_resolver, cu_hash) assert len(key_value) == 2, f"{type_identifier} is not a valid mapping" assert key_value[0] is not None, f"{type_identifier} is not a valid mapping" assert key_value[1] is not None, f"{type_identifier} is not a valid mapping" - self._key_type = key_value[0] - self._value_type = key_value[1] - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Mapping): - return False - return ( - self._key_type == other._key_type and self._value_type == other._value_type - ) + return Mapping(key_value[0], key_value[1]) @property def abi_type(self) -> str: @@ -1254,25 +1473,15 @@ class Contract(TypeAbc): def __init__( self, - type_identifier: StringReader, + is_super: bool, + name: str, + ast_id: AstNodeId, reference_resolver: ReferenceResolver, cu_hash: bytes, ): - if type_identifier.startswith("t_contract"): - self._is_super = False - type_identifier.read("t_contract") - elif type_identifier.startswith("t_super"): - self._is_super = True - type_identifier.read("t_super") - else: - assert False, f"Unexpected contract type {type_identifier}" - self._name = _parse_user_identifier(type_identifier) - - match = NUMBER_RE.match(type_identifier.data) - assert match is not None, f"{type_identifier} is not a valid contract" - self._ast_id = AstNodeId(int(match.group("number"))) - type_identifier.read(match.group("number")) - + self._is_super = is_super + self._name = name + self._ast_id = ast_id self._reference_resolver = reference_resolver self._cu_hash = cu_hash @@ -1285,6 +1494,30 @@ def __eq__(self, other: object) -> bool: and self.ir_node == other.ir_node ) + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> Contract: + if type_identifier.startswith("t_contract"): + is_super = False + type_identifier.read("t_contract") + elif type_identifier.startswith("t_super"): + is_super = True + type_identifier.read("t_super") + else: + assert False, f"Unexpected contract type {type_identifier}" + name = _parse_user_identifier(type_identifier) + + match = NUMBER_RE.match(type_identifier.data) + assert match is not None, f"{type_identifier} is not a valid contract" + ast_id = AstNodeId(int(match.group("number"))) + type_identifier.read(match.group("number")) + + return Contract(is_super, name, ast_id, reference_resolver, cu_hash) + @property def abi_type(self) -> str: return "address" @@ -1349,47 +1582,65 @@ class Struct(TypeAbc): def __init__( self, - type_identifier: StringReader, + name: str, + data_location: DataLocation, + is_pointer: bool, + ast_id: AstNodeId, reference_resolver: ReferenceResolver, cu_hash: bytes, ): + self._name = name + self._data_location = data_location + self._is_pointer = is_pointer + self._ast_id = ast_id + self._reference_resolver = reference_resolver + self._cu_hash = cu_hash + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Struct): + return False + return ( + self._name == other._name + and self._data_location == other._data_location + and self._is_pointer == other._is_pointer + and self.ir_node == other.ir_node + ) + + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> Struct: type_identifier.read("t_struct") - self._name = _parse_user_identifier(type_identifier) + name = _parse_user_identifier(type_identifier) match = NUMBER_RE.match(type_identifier.data) assert match is not None, f"{type_identifier} is not a valid struct" - self._ast_id = AstNodeId(int(match.group("number"))) + ast_id = AstNodeId(int(match.group("number"))) type_identifier.read(match.group("number")) if type_identifier.startswith("_storage"): - self._data_location = DataLocation.STORAGE + data_location = DataLocation.STORAGE type_identifier.read("_storage") elif type_identifier.startswith("_memory"): - self._data_location = DataLocation.MEMORY + data_location = DataLocation.MEMORY type_identifier.read("_memory") elif type_identifier.startswith("_calldata"): - self._data_location = DataLocation.CALLDATA + data_location = DataLocation.CALLDATA type_identifier.read("_calldata") else: assert False, f"Unexpected array type data location {type_identifier}" if type_identifier.startswith("_ptr"): - self._is_pointer = True + is_pointer = True type_identifier.read("_ptr") else: - self._is_pointer = False - - self._reference_resolver = reference_resolver - self._cu_hash = cu_hash + is_pointer = False - def __eq__(self, other: object) -> bool: - if not isinstance(other, Struct): - return False - return ( - self._name == other._name - and self._data_location == other._data_location - and self._is_pointer == other._is_pointer - and self.ir_node == other.ir_node + return Struct( + name, data_location, is_pointer, ast_id, reference_resolver, cu_hash ) @property @@ -1460,18 +1711,13 @@ class Enum(TypeAbc): def __init__( self, - type_identifier: StringReader, + name: str, + ast_id: AstNodeId, reference_resolver: ReferenceResolver, cu_hash: bytes, ): - type_identifier.read("t_enum") - self._name = _parse_user_identifier(type_identifier) - - match = NUMBER_RE.match(type_identifier.data) - assert match is not None, f"{type_identifier} is not a valid enum" - self._ast_id = AstNodeId(int(match.group("number"))) - type_identifier.read(match.group("number")) - + self._name = name + self._ast_id = ast_id self._reference_resolver = reference_resolver self._cu_hash = cu_hash @@ -1480,6 +1726,23 @@ def __eq__(self, other: object) -> bool: return False return self._name == other._name and self.ir_node == other.ir_node + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> Enum: + type_identifier.read("t_enum") + name = _parse_user_identifier(type_identifier) + + match = NUMBER_RE.match(type_identifier.data) + assert match is not None, f"{type_identifier} is not a valid enum" + ast_id = AstNodeId(int(match.group("number"))) + type_identifier.read(match.group("number")) + + return Enum(name, ast_id, reference_resolver, cu_hash) + @property def abi_type(self) -> str: return "uint8" @@ -1513,32 +1776,9 @@ class Magic(TypeAbc): _kind: MagicTypeKind _meta_argument_type: typ.Optional[TypeAbc] - def __init__( - self, - type_identifier: StringReader, - reference_resolver: ReferenceResolver, - cu_hash: bytes, - ): - type_identifier.read("t_magic_") - - matched = False - for kind in MagicTypeKind: - if type_identifier.startswith(kind): - self._kind = MagicTypeKind(kind) - type_identifier.read(kind) - matched = True - break - assert matched, f"Unexpected magic kind {type_identifier}" - - if self._kind == MagicTypeKind.META_TYPE: - type_identifier.read("_") - meta_argument_type = TypeAbc.from_type_identifier( - type_identifier, reference_resolver, cu_hash - ) - assert meta_argument_type is not None - self._meta_argument_type = meta_argument_type - else: - self._meta_argument_type = None + def __init__(self, kind: MagicTypeKind, meta_argument_type: typ.Optional[TypeAbc]): + self._kind = kind + self._meta_argument_type = meta_argument_type def __eq__(self, other: object) -> bool: if not isinstance(other, Magic): @@ -1548,6 +1788,34 @@ def __eq__(self, other: object) -> bool: and self._meta_argument_type == other._meta_argument_type ) + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> Magic: + type_identifier.read("t_magic_") + + for magic_kind in MagicTypeKind: + if type_identifier.startswith(magic_kind): + kind = MagicTypeKind(magic_kind) + type_identifier.read(magic_kind) + + if kind == MagicTypeKind.META_TYPE: + type_identifier.read("_") + meta_argument_type = TypeAbc.from_type_identifier( + type_identifier, reference_resolver, cu_hash + ) + assert meta_argument_type is not None + meta_argument_type = meta_argument_type + else: + meta_argument_type = None + + return Magic(kind, meta_argument_type) + + assert False, f"Unexpected magic type {type_identifier}" + @property def abi_type(self) -> str: raise NotImplementedError @@ -1581,18 +1849,13 @@ class UserDefinedValueType(TypeAbc): def __init__( self, - type_identifier: StringReader, + name: str, + ast_id: AstNodeId, reference_resolver: ReferenceResolver, cu_hash: bytes, ): - type_identifier.read("t_userDefinedValueType") - self._name = _parse_user_identifier(type_identifier) - - match = NUMBER_RE.match(type_identifier.data) - assert match is not None, f"{type_identifier} is not a valid enum" - self._ast_id = AstNodeId(int(match.group("number"))) - type_identifier.read(match.group("number")) - + self._name = name + self._ast_id = ast_id self._reference_resolver = reference_resolver self._cu_hash = cu_hash @@ -1601,6 +1864,23 @@ def __eq__(self, other: object) -> bool: return False return self._name == other._name and self.ir_node == other.ir_node + @classmethod + def from_type_identifier( + cls, + type_identifier: StringReader, + reference_resolver: ReferenceResolver, + cu_hash: bytes, + ) -> UserDefinedValueType: + type_identifier.read("t_userDefinedValueType") + name = _parse_user_identifier(type_identifier) + + match = NUMBER_RE.match(type_identifier.data) + assert match is not None, f"{type_identifier} is not a valid enum" + ast_id = AstNodeId(int(match.group("number"))) + type_identifier.read(match.group("number")) + + return UserDefinedValueType(name, ast_id, reference_resolver, cu_hash) + @property def abi_type(self) -> str: return self.ir_node.underlying_type.type.abi_type @@ -1640,27 +1920,32 @@ class Module(TypeAbc): _cu_hash: bytes def __init__( - self, + self, source_unit_id: int, reference_resolver: ReferenceResolver, cu_hash: bytes + ): + self._source_unit_id = source_unit_id + self._reference_resolver = reference_resolver + self._cu_hash = cu_hash + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Module): + return False + return self.file == other.file + + @classmethod + def from_type_identifier( + cls, type_identifier: StringReader, reference_resolver: ReferenceResolver, cu_hash: bytes, - ): + ) -> Module: type_identifier.read("t_module_") match = NUMBER_RE.match(type_identifier.data) assert match is not None, f"{type_identifier} is not a valid module" - self.__ast_id = AstNodeId(int(match.group("number"))) + source_unit_id = int(match.group("number")) type_identifier.read(match.group("number")) - self._source_unit_id = int(match.group("number")) - - self._reference_resolver = reference_resolver - self._cu_hash = cu_hash - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Module): - return False - return self.file == other.file + return Module(source_unit_id, reference_resolver, cu_hash) @property def abi_type(self) -> str: