Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: replace string types with force-validated pydantic types #59

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 48 additions & 42 deletions eip712/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,52 @@
# Collection of commonly-used EIP712 message type definitions
from typing import Optional, Type, Union

from eth_pydantic_types.abi import address, bytes, bytes32, string, uint8, uint256

from .messages import EIP712Message


class EIP2612(EIP712Message):
# NOTE: Subclass this w/ at least one header field

owner: "address" # type: ignore
spender: "address" # type: ignore
value: "uint256" # type: ignore
nonce: "uint256" # type: ignore
deadline: "uint256" # type: ignore
owner: address
spender: address
value: uint256
nonce: uint256
deadline: uint256


class EIP4494(EIP712Message):
# NOTE: Subclass this w/ at least one header field

spender: "address" # type: ignore
tokenId: "uint256" # type: ignore
nonce: "uint256" # type: ignore
deadline: "uint256" # type: ignore
spender: address
tokenId: uint256
nonce: uint256
deadline: uint256


def create_permit_def(eip=2612, **header_fields):
if eip == 2612:

class Permit(EIP2612):
_name_ = header_fields.get("name", None)
_version_ = header_fields.get("version", None)
_chainId_ = header_fields.get("chainId", None)
_verifyingContract_ = header_fields.get("verifyingContract", None)
_salt_ = header_fields.get("salt", None)
eip712_name_: Optional[string] = header_fields.get("name", None)
eip712_version_: Optional[string] = header_fields.get("version", None)
eip712_chainId_: Optional[uint256] = header_fields.get("chainId", None)
eip712_verifyingContract_: Optional[string] = header_fields.get(
"verifyingContract", None
)
eip712_salt_: Optional[bytes32] = header_fields.get("salt", None)

elif eip == 4494:

class Permit(EIP4494):
_name_ = header_fields.get("name", None)
_version_ = header_fields.get("version", None)
_chainId_ = header_fields.get("chainId", None)
_verifyingContract_ = header_fields.get("verifyingContract", None)
_salt_ = header_fields.get("salt", None)
eip712_name_: Optional[string] = header_fields.get("name", None)
eip712_version_: Optional[string] = header_fields.get("version", None)
eip712_chainId_: Optional[uint256] = header_fields.get("chainId", None)
eip712_verifyingContract_: Optional[string] = header_fields.get(
"verifyingContract", None
)
eip712_salt_: Optional[bytes32] = header_fields.get("salt", None)

else:
raise ValueError(f"Invalid eip {eip}, must use one of: {EIP2612}, {EIP4494}")
Expand All @@ -51,30 +57,30 @@ class Permit(EIP4494):

class SafeTxV1(EIP712Message):
# NOTE: Subclass this as `SafeTx` w/ at least one header field
to: "address" # type: ignore
value: "uint256" = 0 # type: ignore
data: "bytes" = b""
operation: "uint8" = 0 # type: ignore
safeTxGas: "uint256" = 0 # type: ignore
dataGas: "uint256" = 0 # type: ignore
gasPrice: "uint256" = 0 # type: ignore
gasToken: "address" = "0x0000000000000000000000000000000000000000" # type: ignore
refundReceiver: "address" = "0x0000000000000000000000000000000000000000" # type: ignore
nonce: "uint256" # type: ignore
to: address
value: uint256 = 0
data: bytes = b""
operation: uint8 = 0
safeTxGas: uint256 = 0
dataGas: uint256 = 0
gasPrice: uint256 = 0
gasToken: address = "0x0000000000000000000000000000000000000000"
refundReceiver: address = "0x0000000000000000000000000000000000000000"
nonce: uint256


class SafeTxV2(EIP712Message):
# NOTE: Subclass this as `SafeTx` w/ at least one header field
to: "address" # type: ignore
value: "uint256" = 0 # type: ignore
data: "bytes" = b""
operation: "uint8" = 0 # type: ignore
safeTxGas: "uint256" = 0 # type: ignore
baseGas: "uint256" = 0 # type: ignore
gasPrice: "uint256" = 0 # type: ignore
gasToken: "address" = "0x0000000000000000000000000000000000000000" # type: ignore
refundReceiver: "address" = "0x0000000000000000000000000000000000000000" # type: ignore
nonce: "uint256" # type: ignore
to: address
value: uint256 = 0
data: bytes = b""
operation: uint8 = 0
safeTxGas: uint256 = 0
baseGas: uint256 = 0
gasPrice: uint256 = 0
gasToken: address = "0x0000000000000000000000000000000000000000"
refundReceiver: address = "0x0000000000000000000000000000000000000000"
nonce: uint256


SafeTx = Union[SafeTxV1, SafeTxV2]
Expand All @@ -97,15 +103,15 @@ def create_safe_tx_def(
if minor < 3:

class SafeTx(SafeTxV1):
_verifyingContract_ = contract_address
eip712_verifyingContract_: address = contract_address

elif not chain_id:
raise ValueError("Must supply 'chain_id=' for Safe versions 1.3.0 or later")

else:

class SafeTx(SafeTxV2): # type: ignore[no-redef]
_chainId_ = chain_id
_verifyingContract_ = contract_address
eip712_chainId_: uint256 = chain_id
eip712_verifyingContract_: address = contract_address

return SafeTx
126 changes: 78 additions & 48 deletions eip712/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
Message classes for typed structured data hashing and signing in Ethereum.
"""

from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

from dataclassy import asdict, dataclass, fields
from eth_abi.abi import is_encodable_type # type: ignore[import-untyped]
from eth_account.messages import SignableMessage, hash_domain, hash_eip712_message
from eth_pydantic_types import Address, HexBytes
from eth_utils import keccak
from eth_utils.curried import ValidationError
from hexbytes import HexBytes
from pydantic import BaseModel, model_validator
from typing_extensions import _AnnotatedAlias

if TYPE_CHECKING:
from eth_pydantic_types.abi import bytes32, string, uint256

# ! Do not change the order of the fields in this list !
# To correctly encode and hash the domain fields, they
Expand All @@ -30,8 +34,7 @@
]


@dataclass(iter=True, slots=True, kwargs=True, kw_only=True)
class EIP712Type:
class EIP712Type(BaseModel):
"""
Dataclass for `EIP-712 <https://eips.ethereum.org/EIPS/eip-712>`__ structured data types
(i.e. the contents of an :class:`EIP712Message`).
Expand All @@ -48,38 +51,65 @@ def _types_(self) -> dict:
"""
types: dict[str, list] = {repr(self): []}

for field in fields(self.__class__):
for field in {
k: v.annotation.__name__ # type: ignore[union-attr]
for k, v in self.model_fields.items()
if not k.startswith("eip712_")
}:
value = getattr(self, field)
if isinstance(value, EIP712Type):
types[repr(self)].append({"name": field, "type": repr(value)})
types.update(value._types_)
else:
# TODO: Use proper ABI typing, not strings
field_type = self.__annotations__[field]
field_type = search_annotations(self, field)

# If the field type is a string, validate through eth-abi
if isinstance(field_type, str):
if not is_encodable_type(field_type):
raise ValidationError(f"'{field}: {field_type}' is not a valid ABI type")
raise ValidationError(f"'{field}: {field_type}' is not a valid ABI Type")

elif issubclass(field_type, EIP712Type):
elif isinstance(field_type, type) and issubclass(field_type, EIP712Type):
field_type = repr(field_type)

else:
raise ValidationError(
f"'{field}' type annotation must either be a subclass of "
f"`EIP712Type` or valid ABI Type string, not {field_type.__name__}"
)
try:
# If field type already has validators or is a known type
# can confirm that type name will be correct
if isinstance(field_type.__value__, _AnnotatedAlias) or issubclass(
field_type.__value__, (Address, HexBytes)
):
field_type = field_type.__name__

except AttributeError:
raise ValidationError(
f"'{field}' type annotation must either be a subclass of "
f"`EIP712Type` or valid ABI Type, not {field_type.__name__}"
)

types[repr(self)].append({"name": field, "type": field_type})

return types

def __getitem__(self, key: str) -> Any:
if (key.startswith("_") and key.endswith("_")) or key not in fields(self.__class__):
if (key.startswith("eip712_") and key.endswith("_")) or key not in self.model_fields:
raise KeyError("Cannot look up header fields or other attributes this way")

return getattr(self, key)

def _prepare_data_for_hashing(self, data: dict) -> dict:
result: dict = {}

for key, value in data.items():
item: Any = value
if isinstance(value, EIP712Type):
item = value.model_dump(mode="json")
elif isinstance(value, dict):
item = self._prepare_data_for_hashing(item)

result[key] = item

return result


class EIP712Message(EIP712Type):
"""
Expand All @@ -88,33 +118,38 @@ class EIP712Message(EIP712Type):
"""

# NOTE: Must override at least one of these fields
_name_: Optional[str] = None
_version_: Optional[str] = None
_chainId_: Optional[int] = None
_verifyingContract_: Optional[str] = None
_salt_: Optional[bytes] = None

def __post_init__(self):
eip712_name_: Optional["string"] = None
eip712_version_: Optional["string"] = None
eip712_chainId_: Optional["uint256"] = None
eip712_verifyingContract_: Optional["string"] = None
eip712_salt_: Optional["bytes32"] = None

@model_validator(mode="after")
@classmethod
def validate_model(cls, value):
# At least one of the header fields must be in the EIP712 message header
if not any(getattr(self, f"_{field}_") for field in EIP712_DOMAIN_FIELDS):
if not any(f"eip712_{field}_" in value.__annotations__ for field in EIP712_DOMAIN_FIELDS):
raise ValidationError(
f"EIP712 Message definition '{repr(self)}' must define "
f"at least one of: _{'_, _'.join(EIP712_DOMAIN_FIELDS)}_"
f"EIP712 Message definition '{repr(cls)}' must define "
f"at least one of: eip712_{'_, eip712_'.join(EIP712_DOMAIN_FIELDS)}_"
)
return value

@property
def _domain_(self) -> dict:
"""The EIP-712 domain structure to be used for serialization and hashing."""
domain_type = [
{"name": field, "type": abi_type}
for field, abi_type in EIP712_DOMAIN_FIELDS.items()
if getattr(self, f"_{field}_")
if getattr(self, f"eip712_{field}_")
]
return {
"types": {
"EIP712Domain": domain_type,
},
"domain": {field["name"]: getattr(self, f"_{field['name']}_") for field in domain_type},
"domain": {
field["name"]: getattr(self, f"eip712_{field['name']}_") for field in domain_type
},
}

@property
Expand All @@ -126,9 +161,10 @@ def _body_(self) -> dict:
"types": dict(self._types_, **self._domain_["types"]),
"primaryType": repr(self),
"message": {
# TODO use __pydantic_extra__ instead
key: getattr(self, key)
for key in fields(self.__class__)
if not key.startswith("_") or not key.endswith("_")
for key in self.model_fields
if not key.startswith("eip712_") or not key.endswith("_")
},
}

Expand All @@ -144,30 +180,24 @@ def signable_message(self) -> SignableMessage:
The current message as a :class:`SignableMessage` named tuple instance.
**NOTE**: The 0x19 prefix is NOT included.
"""
domain = _prepare_data_for_hashing(self._domain_["domain"])
types = _prepare_data_for_hashing(self._types_)
message = _prepare_data_for_hashing(self._body_["message"])
domain = self._prepare_data_for_hashing(self._domain_["domain"])
types = self._prepare_data_for_hashing(self._types_)
message = self._prepare_data_for_hashing(self._body_["message"])
messagebytes = HexBytes(1)
messageDomain = HexBytes(hash_domain(domain))
messageEIP = HexBytes(hash_eip712_message(types, message))
return SignableMessage(
HexBytes(1),
HexBytes(hash_domain(domain)),
HexBytes(hash_eip712_message(types, message)),
messagebytes,
messageDomain,
messageEIP,
)


def calculate_hash(msg: SignableMessage) -> HexBytes:
return HexBytes(keccak(b"".join([bytes.fromhex("19"), *msg])))


def _prepare_data_for_hashing(data: dict) -> dict:
result: dict = {}

for key, value in data.items():
item: Any = value
if isinstance(value, EIP712Type):
item = asdict(value)
elif isinstance(value, dict):
item = _prepare_data_for_hashing(item)

result[key] = item

return result
def search_annotations(cls, field: str) -> Any:
if hasattr(cls, "__annotations__") and field in cls.__annotations__:
return cls.__annotations__[field]
return search_annotations(super(cls.__class__, cls), field)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@
url="https://github.com/ApeWorX/eip712",
include_package_data=True,
install_requires=[
"dataclassy>=0.11.1,<1",
"eth-abi>=5.1.0,<6",
"eth-account>=0.11.3,<0.14",
"eth-pydantic-types>=0.2.0,<1",
"eth-typing>=3.5.2,<6",
"eth-utils>=2.3.1,<6",
"hexbytes>=0.3.1,<2",
Expand Down
Loading
Loading