From 4ffd081a6ace5d55f2300043f6814741e2bf0c53 Mon Sep 17 00:00:00 2001 From: enitrat Date: Tue, 4 Feb 2025 15:56:20 +0000 Subject: [PATCH] feat: Pre feat: prepare_trie cleanup cleanup --- cairo/ethereum/cancun/trie.cairo | 628 +++++++++++++++++- cairo/tests/ethereum/cancun/test_trie.py | 34 +- cairo/tests/utils/args_gen.py | 15 + .../src/cairo_addons/hints/hashdict.py | 2 +- 4 files changed, 656 insertions(+), 23 deletions(-) diff --git a/cairo/ethereum/cancun/trie.cairo b/cairo/ethereum/cancun/trie.cairo index 38c623c2f..c27421b5d 100644 --- a/cairo/ethereum/cancun/trie.cairo +++ b/cairo/ethereum/cancun/trie.cairo @@ -1,4 +1,5 @@ from starkware.cairo.common.cairo_builtins import PoseidonBuiltin +from starkware.cairo.common.default_dict import default_dict_new from starkware.cairo.common.builtin_poseidon.poseidon import poseidon_hash, poseidon_hash_many from starkware.cairo.common.alloc import alloc from starkware.cairo.common.math_cmp import is_le @@ -10,7 +11,7 @@ from starkware.cairo.common.cairo_builtins import KeccakBuiltin from starkware.cairo.common.memcpy import memcpy from src.utils.bytes import uint256_to_bytes32_little -from src.utils.dict import hashdict_read, hashdict_write, dict_new_empty +from src.utils.dict import hashdict_read, hashdict_write, dict_new_empty, dict_squash from ethereum.crypto.hash import keccak256 from ethereum.utils.numeric import min from ethereum_rlp.rlp import encode, _encode_bytes, _encode @@ -18,9 +19,12 @@ from ethereum.utils.numeric import U256__eq__ from ethereum_types.numeric import U256, Uint, bool, U256Struct from ethereum_types.bytes import ( HashedBytes, + HashedBytes32, Bytes, + Bytes20, BytesStruct, Bytes32, + Bytes32Struct, StringStruct, String, MappingBytesBytes, @@ -31,7 +35,9 @@ from ethereum_types.bytes import ( ) from ethereum.cancun.blocks import ( Receipt, + ReceiptStruct, Withdrawal, + WithdrawalStruct, UnionBytesLegacyTransaction, UnionBytesLegacyTransactionEnum, OptionalUnionBytesLegacyTransaction, @@ -55,7 +61,7 @@ from ethereum.cancun.fork_types import ( MappingTupleAddressBytes32U256, MappingTupleAddressBytes32U256Struct, ) -from ethereum.cancun.transactions_types import LegacyTransaction +from ethereum.cancun.transactions_types import LegacyTransaction, LegacyTransactionStruct from ethereum_rlp.rlp import ( Extended, SequenceExtended, @@ -70,6 +76,7 @@ from ethereum_rlp.rlp import ( encode_u256, ) from ethereum.utils.numeric import divmod +from ethereum.utils.bytes import Bytes32_to_Bytes, Bytes20_to_Bytes from cairo_core.comparison import is_zero @@ -168,20 +175,6 @@ namespace InternalNodeImpl { } } -struct NodeEnum { - account: Account, - bytes: Bytes, - legacy_transaction: LegacyTransaction, - receipt: Receipt, - uint: Uint*, - u256: U256, - withdrawal: Withdrawal, -} - -struct Node { - value: NodeEnum*, -} - struct TrieAddressOptionalAccountStruct { secured: bool, default: OptionalAccount, @@ -192,6 +185,10 @@ struct TrieAddressOptionalAccount { value: TrieAddressOptionalAccountStruct*, } +// Internal representation of the Dict[Address, Trie[Bytes32, U256]] +// which holds the storage tries for each account. +// During execution, the storage tries are "merged" into a single trie where the keys are +// the hash of the account address and the storage key. struct TrieTupleAddressBytes32U256Struct { secured: bool, default: U256, @@ -202,6 +199,33 @@ struct TrieTupleAddressBytes32U256 { value: TrieTupleAddressBytes32U256Struct*, } +// To compute storage roots, we will extract mapping of all storage tries for each account. +struct Bytes32U256DictAccess { + key: HashedBytes32, + prev_value: U256, + new_value: U256, +} + +struct MappingBytes32U256Struct { + dict_ptr_start: Bytes32U256DictAccess*, + dict_ptr: Bytes32U256DictAccess*, + parent_dict: MappingBytes32U256Struct*, +} + +struct MappingBytes32U256 { + value: MappingBytes32U256Struct*, +} + +struct TrieBytes32U256Struct { + secured: bool, + default: U256, + _data: MappingBytes32U256, +} + +struct TrieBytes32U256 { + value: TrieBytes32U256Struct*, +} + struct BytesOptionalUnionBytesLegacyTransactionDictAccess { key: HashedBytes, prev_value: OptionalUnionBytesLegacyTransaction, @@ -280,6 +304,32 @@ struct TrieBytesOptionalUnionBytesWithdrawal { value: TrieBytesOptionalUnionBytesWithdrawalStruct*, } +struct UnionEthereumTries { + value: UnionEthereumTriesEnum*, +} + +struct UnionEthereumTriesEnum { + account: TrieAddressOptionalAccount, + storage: TrieBytes32U256, + transaction: TrieBytesOptionalUnionBytesLegacyTransaction, + receipt: TrieBytesOptionalUnionBytesReceipt, + withdrawal: TrieBytesOptionalUnionBytesWithdrawal, +} + +struct NodeEnum { + account: Account, + bytes: Bytes, + legacy_transaction: LegacyTransaction, + receipt: Receipt, + uint: Uint*, + u256: U256, + withdrawal: Withdrawal, +} + +struct Node { + value: NodeEnum*, +} + func encode_internal_node{ range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: KeccakBuiltin* }(node: InternalNode) -> Extended { @@ -843,6 +893,510 @@ func bytes_to_nibble_list{bitwise_ptr: BitwiseBuiltin*}(bytes_: Bytes) -> Bytes return result; } +func _prepare_trie{ + range_check_ptr, + bitwise_ptr: BitwiseBuiltin*, + keccak_ptr: KeccakBuiltin*, + poseidon_ptr: PoseidonBuiltin*, +}(trie_union: UnionEthereumTries) -> MappingBytesBytes { + alloc_locals; + + let (local mapping_ptr_start: BytesBytesDictAccess*) = default_dict_new(0); + + tempvar is_account = cast(trie_union.value.account.value, felt); + jmp account if is_account != 0; + + tempvar is_storage = cast(trie_union.value.storage.value, felt); + jmp storage if is_storage != 0; + + tempvar is_transaction = cast(trie_union.value.transaction.value, felt); + jmp transaction if is_transaction != 0; + + tempvar is_receipt = cast(trie_union.value.receipt.value, felt); + jmp receipt if is_receipt != 0; + + tempvar is_withdrawal = cast(trie_union.value.withdrawal.value, felt); + jmp withdrawal if is_withdrawal != 0; + + with_attr error_message("Invalid trie union") { + assert 0 = 1; + } + + account: + let account_trie = trie_union.value.account; + _prepare_trie_inner_account( + account_trie, account_trie.value._data.value.dict_ptr_start, mapping_ptr_start + ); + jmp end; + + storage: + let storage_trie = trie_union.value.storage; + _prepare_trie_inner_storage( + storage_trie, storage_trie.value._data.value.dict_ptr_start, mapping_ptr_start + ); + jmp end; + + transaction: + let transaction_trie = trie_union.value.transaction; + _prepare_trie_inner_transaction( + transaction_trie, transaction_trie.value._data.value.dict_ptr_start, mapping_ptr_start + ); + jmp end; + + receipt: + let receipt_trie = trie_union.value.receipt; + _prepare_trie_inner_receipt( + receipt_trie, receipt_trie.value._data.value.dict_ptr_start, mapping_ptr_start + ); + jmp end; + + withdrawal: + let withdrawal_trie = trie_union.value.withdrawal; + _prepare_trie_inner_withdrawal( + withdrawal_trie, withdrawal_trie.value._data.value.dict_ptr_start, mapping_ptr_start + ); + jmp end; + + end: + let range_check_ptr = [ap - 5]; + let bitwise_ptr = cast([ap - 4], BitwiseBuiltin*); + let keccak_ptr = cast([ap - 3], KeccakBuiltin*); + let poseidon_ptr = cast([ap - 2], PoseidonBuiltin*); + let mapping_ptr_end = cast([ap - 1], BytesBytesDictAccess*); + + tempvar result = MappingBytesBytes( + new MappingBytesBytesStruct( + cast(mapping_ptr_start, BytesBytesDictAccess*), + cast(mapping_ptr_end, BytesBytesDictAccess*), + cast(0, MappingBytesBytesStruct*), + ), + ); + return result; +} + +func _prepare_trie_inner_account{ + range_check_ptr, + bitwise_ptr: BitwiseBuiltin*, + keccak_ptr: KeccakBuiltin*, + poseidon_ptr: PoseidonBuiltin*, +}( + trie: TrieAddressOptionalAccount, + dict_ptr: AddressAccountDictAccess*, + mapping_ptr_end: BytesBytesDictAccess*, +) -> BytesBytesDictAccess* { + alloc_locals; + + if (dict_ptr == trie.value._data.value.dict_ptr) { + return mapping_ptr_end; + } + + let preimage = Bytes20_to_Bytes(dict_ptr.key); + let value = dict_ptr.new_value; + // TODO: get storage root + let (buffer: felt*) = alloc(); + tempvar storage_root = Bytes(new BytesStruct(buffer, 0)); + tempvar node = Node( + new NodeEnum( + account=value, + bytes=Bytes(cast(0, BytesStruct*)), + legacy_transaction=LegacyTransaction(cast(0, LegacyTransactionStruct*)), + receipt=Receipt(cast(0, ReceiptStruct*)), + uint=cast(0, Uint*), + u256=U256(cast(0, U256Struct*)), + withdrawal=Withdrawal(cast(0, WithdrawalStruct*)), + ), + ); + let encoded_value = encode_node(node, storage_root); + + if (encoded_value.value.len == 0) { + with_attr error_message("AssertionError") { + assert 0 = 1; + } + } + + // TODO: Common part, factorise. + + if (trie.value.secured.value != 0) { + let key_bytes32 = keccak256(preimage); + let key_bytes = Bytes32_to_Bytes(key_bytes32); + tempvar range_check_ptr = range_check_ptr; + tempvar bitwise_ptr = bitwise_ptr; + tempvar keccak_ptr = keccak_ptr; + } else { + tempvar key_bytes = preimage; + tempvar range_check_ptr = range_check_ptr; + tempvar bitwise_ptr = bitwise_ptr; + tempvar keccak_ptr = keccak_ptr; + } + let key_bytes = Bytes(cast([ap - 4], BytesStruct*)); + let range_check_ptr = [ap - 3]; + let bitwise_ptr = cast([ap - 2], BitwiseBuiltin*); + let keccak_ptr = cast([ap - 1], KeccakBuiltin*); + + let nibbles_list = bytes_to_nibble_list(key_bytes); + let mapping_dict_ptr = cast(mapping_ptr_end, DictAccess*); + hashdict_write{dict_ptr=mapping_dict_ptr}( + nibbles_list.value.len, nibbles_list.value.data, cast(encoded_value.value, felt) + ); + + return _prepare_trie_inner_account( + trie, + dict_ptr + AddressAccountDictAccess.SIZE, + cast(mapping_dict_ptr, BytesBytesDictAccess*), + ); +} + +func _prepare_trie_inner_storage{ + range_check_ptr, + bitwise_ptr: BitwiseBuiltin*, + keccak_ptr: KeccakBuiltin*, + poseidon_ptr: PoseidonBuiltin*, +}( + trie: TrieBytes32U256, dict_ptr: Bytes32U256DictAccess*, mapping_ptr_end: BytesBytesDictAccess* +) -> BytesBytesDictAccess* { + alloc_locals; + + if (dict_ptr == trie.value._data.value.dict_ptr) { + return mapping_ptr_end; + } + + let preimage_b32 = _get_bytes32_preimage_for_key( + dict_ptr.key.value, cast(trie.value._data.value.dict_ptr, DictAccess*) + ); + let preimage = Bytes32_to_Bytes(preimage_b32); + + let value = dict_ptr.new_value; + tempvar node = Node( + new NodeEnum( + account=Account(cast(0, AccountStruct*)), + bytes=Bytes(cast(0, BytesStruct*)), + legacy_transaction=LegacyTransaction(cast(0, LegacyTransactionStruct*)), + receipt=Receipt(cast(0, ReceiptStruct*)), + uint=cast(0, Uint*), + u256=value, + withdrawal=Withdrawal(cast(0, WithdrawalStruct*)), + ), + ); + let encoded_value = encode_node(node, Bytes(cast(0, BytesStruct*))); + + // TODO: Common part, factorise. + if (encoded_value.value.len == 0) { + with_attr error_message("AssertionError") { + assert 0 = 1; + } + } + + if (trie.value.secured.value != 0) { + let key_bytes32 = keccak256(preimage); + let key_bytes = Bytes32_to_Bytes(key_bytes32); + tempvar range_check_ptr = range_check_ptr; + tempvar bitwise_ptr = bitwise_ptr; + tempvar keccak_ptr = keccak_ptr; + } else { + tempvar key_bytes = preimage; + tempvar range_check_ptr = range_check_ptr; + tempvar bitwise_ptr = bitwise_ptr; + tempvar keccak_ptr = keccak_ptr; + } + let key_bytes = Bytes(cast([ap - 4], BytesStruct*)); + let range_check_ptr = [ap - 3]; + let bitwise_ptr = cast([ap - 2], BitwiseBuiltin*); + let keccak_ptr = cast([ap - 1], KeccakBuiltin*); + + let nibbles_list = bytes_to_nibble_list(key_bytes); + let mapping_dict_ptr = cast(mapping_ptr_end, DictAccess*); + hashdict_write{dict_ptr=mapping_dict_ptr}( + nibbles_list.value.len, nibbles_list.value.data, cast(encoded_value.value, felt) + ); + + return _prepare_trie_inner_storage( + trie, dict_ptr + Bytes32U256DictAccess.SIZE, cast(mapping_dict_ptr, BytesBytesDictAccess*) + ); +} + +func _prepare_trie_inner_transaction{ + range_check_ptr, + bitwise_ptr: BitwiseBuiltin*, + keccak_ptr: KeccakBuiltin*, + poseidon_ptr: PoseidonBuiltin*, +}( + trie: TrieBytesOptionalUnionBytesLegacyTransaction, + dict_ptr: BytesOptionalUnionBytesLegacyTransactionDictAccess*, + mapping_ptr_end: BytesBytesDictAccess*, +) -> BytesBytesDictAccess* { + alloc_locals; + + if (dict_ptr == trie.value._data.value.dict_ptr) { + return mapping_ptr_end; + } + + let preimage = _get_bytes_preimage_for_key( + dict_ptr.key.value, cast(trie.value._data.value.dict_ptr, DictAccess*) + ); + let value = dict_ptr.new_value; + + // Skip all None values + if (cast(dict_ptr.new_value.value, felt) == 0) { + return _prepare_trie_inner_transaction( + trie, + dict_ptr + BytesOptionalUnionBytesLegacyTransactionDictAccess.SIZE, + mapping_ptr_end, + ); + } + + // Create the correct node type + + if (dict_ptr.new_value.value.bytes.value != 0) { + tempvar node = Node( + new NodeEnum( + account=Account(cast(0, AccountStruct*)), + bytes=dict_ptr.new_value.value.bytes, + legacy_transaction=LegacyTransaction(cast(0, LegacyTransactionStruct*)), + receipt=Receipt(cast(0, ReceiptStruct*)), + uint=cast(0, Uint*), + u256=U256(cast(0, U256Struct*)), + withdrawal=Withdrawal(cast(0, WithdrawalStruct*)), + ), + ); + } else { + tempvar node = Node( + new NodeEnum( + account=Account(cast(0, AccountStruct*)), + bytes=Bytes(cast(0, BytesStruct*)), + legacy_transaction=dict_ptr.new_value.value.legacy_transaction, + receipt=Receipt(cast(0, ReceiptStruct*)), + uint=cast(0, Uint*), + u256=U256(cast(0, U256Struct*)), + withdrawal=Withdrawal(cast(0, WithdrawalStruct*)), + ), + ); + } + + let encoded_value = encode_node(node, Bytes(cast(0, BytesStruct*))); + + if (encoded_value.value.len == 0) { + with_attr error_message("AssertionError") { + assert 0 = 1; + } + } + + if (trie.value.secured.value != 0) { + let key_bytes32 = keccak256(preimage); + let key_bytes = Bytes32_to_Bytes(key_bytes32); + tempvar range_check_ptr = range_check_ptr; + tempvar bitwise_ptr = bitwise_ptr; + tempvar keccak_ptr = keccak_ptr; + } else { + tempvar key_bytes = preimage; + tempvar range_check_ptr = range_check_ptr; + tempvar bitwise_ptr = bitwise_ptr; + tempvar keccak_ptr = keccak_ptr; + } + let key_bytes = Bytes(cast([ap - 4], BytesStruct*)); + let range_check_ptr = [ap - 3]; + let bitwise_ptr = cast([ap - 2], BitwiseBuiltin*); + let keccak_ptr = cast([ap - 1], KeccakBuiltin*); + + let nibbles_list = bytes_to_nibble_list(key_bytes); + let mapping_dict_ptr = cast(mapping_ptr_end, DictAccess*); + hashdict_write{dict_ptr=mapping_dict_ptr}( + nibbles_list.value.len, nibbles_list.value.data, cast(encoded_value.value, felt) + ); + + return _prepare_trie_inner_transaction( + trie, + dict_ptr + BytesOptionalUnionBytesLegacyTransactionDictAccess.SIZE, + cast(mapping_dict_ptr, BytesBytesDictAccess*), + ); +} + +func _prepare_trie_inner_receipt{ + range_check_ptr, + bitwise_ptr: BitwiseBuiltin*, + keccak_ptr: KeccakBuiltin*, + poseidon_ptr: PoseidonBuiltin*, +}( + trie: TrieBytesOptionalUnionBytesReceipt, + dict_ptr: BytesOptionalUnionBytesReceiptDictAccess*, + mapping_ptr_end: BytesBytesDictAccess*, +) -> BytesBytesDictAccess* { + alloc_locals; + + if (dict_ptr == trie.value._data.value.dict_ptr) { + return mapping_ptr_end; + } + + let preimage = _get_bytes_preimage_for_key( + dict_ptr.key.value, cast(trie.value._data.value.dict_ptr, DictAccess*) + ); + let value = dict_ptr.new_value; + + // Skip all None values + if (cast(dict_ptr.new_value.value, felt) == 0) { + return _prepare_trie_inner_receipt( + trie, dict_ptr + BytesOptionalUnionBytesReceiptDictAccess.SIZE, mapping_ptr_end + ); + } + + // Create the correct node type + + if (dict_ptr.new_value.value.bytes.value != 0) { + tempvar node = Node( + new NodeEnum( + account=Account(cast(0, AccountStruct*)), + bytes=dict_ptr.new_value.value.bytes, + legacy_transaction=LegacyTransaction(cast(0, LegacyTransactionStruct*)), + receipt=Receipt(cast(0, ReceiptStruct*)), + uint=cast(0, Uint*), + u256=U256(cast(0, U256Struct*)), + withdrawal=Withdrawal(cast(0, WithdrawalStruct*)), + ), + ); + } else { + tempvar node = Node( + new NodeEnum( + account=Account(cast(0, AccountStruct*)), + bytes=Bytes(cast(0, BytesStruct*)), + legacy_transaction=LegacyTransaction(cast(0, LegacyTransactionStruct*)), + receipt=dict_ptr.new_value.value.receipt, + uint=cast(0, Uint*), + u256=U256(cast(0, U256Struct*)), + withdrawal=Withdrawal(cast(0, WithdrawalStruct*)), + ), + ); + } + + let encoded_value = encode_node(node, Bytes(cast(0, BytesStruct*))); + + if (encoded_value.value.len == 0) { + with_attr error_message("AssertionError") { + assert 0 = 1; + } + } + + if (trie.value.secured.value != 0) { + let key_bytes32 = keccak256(preimage); + let key_bytes = Bytes32_to_Bytes(key_bytes32); + tempvar range_check_ptr = range_check_ptr; + tempvar bitwise_ptr = bitwise_ptr; + tempvar keccak_ptr = keccak_ptr; + } else { + tempvar key_bytes = preimage; + tempvar range_check_ptr = range_check_ptr; + tempvar bitwise_ptr = bitwise_ptr; + tempvar keccak_ptr = keccak_ptr; + } + let key_bytes = Bytes(cast([ap - 4], BytesStruct*)); + let range_check_ptr = [ap - 3]; + let bitwise_ptr = cast([ap - 2], BitwiseBuiltin*); + let keccak_ptr = cast([ap - 1], KeccakBuiltin*); + + let nibbles_list = bytes_to_nibble_list(key_bytes); + let mapping_dict_ptr = cast(mapping_ptr_end, DictAccess*); + hashdict_write{dict_ptr=mapping_dict_ptr}( + nibbles_list.value.len, nibbles_list.value.data, cast(encoded_value.value, felt) + ); + + return _prepare_trie_inner_receipt( + trie, + dict_ptr + BytesOptionalUnionBytesReceiptDictAccess.SIZE, + cast(mapping_dict_ptr, BytesBytesDictAccess*), + ); +} + +func _prepare_trie_inner_withdrawal{ + range_check_ptr, + bitwise_ptr: BitwiseBuiltin*, + keccak_ptr: KeccakBuiltin*, + poseidon_ptr: PoseidonBuiltin*, +}( + trie: TrieBytesOptionalUnionBytesWithdrawal, + dict_ptr: BytesOptionalUnionBytesWithdrawalDictAccess*, + mapping_ptr_end: BytesBytesDictAccess*, +) -> BytesBytesDictAccess* { + alloc_locals; + + if (dict_ptr == trie.value._data.value.dict_ptr) { + return mapping_ptr_end; + } + + let preimage = _get_bytes_preimage_for_key( + dict_ptr.key.value, cast(trie.value._data.value.dict_ptr, DictAccess*) + ); + let value = dict_ptr.new_value; + + // Skip all None values + if (cast(dict_ptr.new_value.value, felt) == 0) { + return _prepare_trie_inner_withdrawal( + trie, dict_ptr + BytesOptionalUnionBytesWithdrawalDictAccess.SIZE, mapping_ptr_end + ); + } + + // Create the correct node type + if (dict_ptr.new_value.value.bytes.value != 0) { + tempvar node = Node( + new NodeEnum( + account=Account(cast(0, AccountStruct*)), + bytes=dict_ptr.new_value.value.bytes, + legacy_transaction=LegacyTransaction(cast(0, LegacyTransactionStruct*)), + receipt=Receipt(cast(0, ReceiptStruct*)), + uint=cast(0, Uint*), + u256=U256(cast(0, U256Struct*)), + withdrawal=Withdrawal(cast(0, WithdrawalStruct*)), + ), + ); + } else { + tempvar node = Node( + new NodeEnum( + account=Account(cast(0, AccountStruct*)), + bytes=Bytes(cast(0, BytesStruct*)), + legacy_transaction=LegacyTransaction(cast(0, LegacyTransactionStruct*)), + receipt=Receipt(cast(0, ReceiptStruct*)), + uint=cast(0, Uint*), + u256=U256(cast(0, U256Struct*)), + withdrawal=dict_ptr.new_value.value.withdrawal, + ), + ); + } + + let encoded_value = encode_node(node, Bytes(cast(0, BytesStruct*))); + + if (encoded_value.value.len == 0) { + with_attr error_message("AssertionError") { + assert 0 = 1; + } + } + + if (trie.value.secured.value != 0) { + let key_bytes32 = keccak256(preimage); + let key_bytes = Bytes32_to_Bytes(key_bytes32); + tempvar range_check_ptr = range_check_ptr; + tempvar bitwise_ptr = bitwise_ptr; + tempvar keccak_ptr = keccak_ptr; + } else { + tempvar key_bytes = preimage; + tempvar range_check_ptr = range_check_ptr; + tempvar bitwise_ptr = bitwise_ptr; + tempvar keccak_ptr = keccak_ptr; + } + let key_bytes = Bytes(cast([ap - 4], BytesStruct*)); + let range_check_ptr = [ap - 3]; + let bitwise_ptr = cast([ap - 2], BitwiseBuiltin*); + let keccak_ptr = cast([ap - 1], KeccakBuiltin*); + + let nibbles_list = bytes_to_nibble_list(key_bytes); + let mapping_dict_ptr = cast(mapping_ptr_end, DictAccess*); + hashdict_write{dict_ptr=mapping_dict_ptr}( + nibbles_list.value.len, nibbles_list.value.data, cast(encoded_value.value, felt) + ); + + return _prepare_trie_inner_withdrawal( + trie, + dict_ptr + BytesOptionalUnionBytesWithdrawalDictAccess.SIZE, + cast(mapping_dict_ptr, BytesBytesDictAccess*), + ); +} + // func _prepare_trie(trie: Trie[K, V], get_storage_root: Callable[List(elts=[Name(id='Address', ctx=Load())], ctx=Load()), Root]) -> Mapping[Bytes, Bytes] { // // Implementation: // // mapped: MutableMapping[Bytes, Bytes] = {} @@ -935,7 +1489,7 @@ func _search_common_prefix_length{ return current_length; } - let preimage = _get_preimage_for_key(obj.key.value, dict_ptr_stop); + let preimage = _get_bytes_preimage_for_key(obj.key.value, cast(dict_ptr_stop, DictAccess*)); tempvar sliced_key = Bytes( new BytesStruct(preimage.value.data + level.value, preimage.value.len - level.value) ); @@ -961,7 +1515,9 @@ func _get_branch_for_nibble_at_level_inner{poseidon_ptr: PoseidonBuiltin*}( return (branch_ptr, value); } - let preimage = _get_preimage_for_key(dict_ptr.key.value, dict_ptr_stop); + let preimage = _get_bytes_preimage_for_key( + dict_ptr.key.value, cast(dict_ptr_stop, DictAccess*) + ); // Check cases let is_value_case = is_zero(preimage.value.len - level); @@ -1185,8 +1741,8 @@ func _get_branches{poseidon_ptr: PoseidonBuiltin*}(obj: MappingBytesBytes, level // The preimage is validated to be correctly provided by the prover by hashing it and comparing it to the key. // @param key - The key to get the preimage for. Either a hashed or non-hashed key - but it must be a felt. // @param dict_ptr_stop - The pointer to the end of the dict segment, the one registered in the tracker. -func _get_preimage_for_key{poseidon_ptr: PoseidonBuiltin*}( - key: felt, dict_ptr_stop: BytesBytesDictAccess* +func _get_bytes_preimage_for_key{poseidon_ptr: PoseidonBuiltin*}( + key: felt, dict_ptr_stop: DictAccess* ) -> Bytes { alloc_locals; @@ -1214,6 +1770,36 @@ func _get_preimage_for_key{poseidon_ptr: PoseidonBuiltin*}( return res; } +// @notice Given a key (inside `dict_ptr`), returns the bytes32 preimage of the key registered in the tracker. +// The preimage is validated to be correctly provided by the prover by hashing it and comparing it to the key. +// @param key - The key to get the preimage for. Either a hashed or non-hashed key - but it must be a felt. +// @param dict_ptr_stop - The pointer to the end of the dict segment, the one registered in the tracker. +func _get_bytes32_preimage_for_key{poseidon_ptr: PoseidonBuiltin*}( + key: felt, dict_ptr_stop: DictAccess* +) -> Bytes32 { + alloc_locals; + + // Get preimage data + let (local preimage_data: felt*) = alloc(); + local preimage_len; + %{ get_preimage_for_key %} + + // Verify preimage + if (preimage_len != 2) { + with_attr error_message("_get_bytes32_preimage_for_key: preimage_len != 2") { + assert 0 = 1; + } + } + + let (preimage_hash) = poseidon_hash_many(preimage_len, preimage_data); + with_attr error_message("preimage_hash != key") { + assert preimage_hash = key; + } + + tempvar res = Bytes32(new Bytes32Struct(preimage_data[0], preimage_data[1])); + return res; +} + // @dev The obj mapping needs to be squashed before calling this function. // @dev No other squashing is required after this function returns as it only reads from the DictAccess segment. // @dev This function could be made faster by sorting the DictAccess segment by key before processing it. @@ -1233,7 +1819,7 @@ func patricialize{ let arbitrary_value = obj.value.dict_ptr_start.new_value; let current_key = obj.value.dict_ptr_start.key.value; - let preimage = _get_preimage_for_key(current_key, obj.value.dict_ptr); + let preimage = _get_bytes_preimage_for_key(current_key, cast(obj.value.dict_ptr, DictAccess*)); // if leaf node if (len == 1) { diff --git a/cairo/tests/ethereum/cancun/test_trie.py b/cairo/tests/ethereum/cancun/test_trie.py index 64223d6b5..978a2cf86 100644 --- a/cairo/tests/ethereum/cancun/test_trie.py +++ b/cairo/tests/ethereum/cancun/test_trie.py @@ -7,6 +7,7 @@ InternalNode, Node, Trie, + _prepare_trie, bytes_to_nibble_list, common_prefix_length, copy_trie, @@ -24,7 +25,7 @@ from cairo_addons.testing.hints import patch_hint from tests.utils.assertion import sequence_equal -from tests.utils.errors import cairo_error +from tests.utils.errors import cairo_error, strict_raises from tests.utils.strategies import bytes32, nibble, uint4 @@ -140,6 +141,37 @@ def test_get_branches(self, cairo_run, obj, level): def test_patricialize(self, cairo_run, obj: Mapping[Bytes, Bytes]): assert patricialize(obj, Uint(0)) == cairo_run("patricialize", obj, Uint(0)) + @given(trie=...) + def test_prepare_trie( + self, + cairo_run, + trie: Union[ + Trie[Bytes, Optional[Union[Bytes, Withdrawal]]], # Withdrawal Trie + Trie[Bytes, Optional[Union[Bytes, LegacyTransaction]]], # Transaction Trie + Trie[Bytes, Optional[Union[Bytes, Receipt]]], # Receipt Trie + Trie[Address, Optional[Account]], # Account Trie + Trie[Bytes32, U256], # Storage Trie + ], + ): + # TODO: compute storage root + key_type, _ = trie.__orig_class__.__args__ + if key_type is Address: + + def get_storage_root(_address): + return b"" + + else: + get_storage_root = None + + try: + result_cairo = cairo_run("_prepare_trie", trie, get_storage_root) + except Exception as e: + with strict_raises(type(e)): + _prepare_trie(trie, get_storage_root) + return + + assert result_cairo == _prepare_trie(trie, get_storage_root) + class TestTrieOperations: class TestGet: diff --git a/cairo/tests/utils/args_gen.py b/cairo/tests/utils/args_gen.py index 6fc2d15d6..cf9cca0cc 100644 --- a/cairo/tests/utils/args_gen.py +++ b/cairo/tests/utils/args_gen.py @@ -635,6 +635,17 @@ def __eq__(self, other): ("ethereum", "cancun", "vm", "Stack"): Stack[U256], ("ethereum", "cancun", "vm", "gas", "ExtendMemory"): ExtendMemory, ("ethereum", "cancun", "vm", "interpreter", "MessageCallOutput"): MessageCallOutput, + # Union of all possible trie types as defined in the ethereum spec. + # Does not take into account our internal trie where we merged accounts and storage. + # ! Order matters here. + ("ethereum", "cancun", "trie", "UnionEthereumTries"): Union[ + Trie[Address, Optional[Account]], + Trie[Bytes32, U256], + Trie[Bytes, Optional[Union[Bytes, LegacyTransaction]]], + Trie[Bytes, Optional[Union[Bytes, Receipt]]], + Trie[Bytes, Optional[Union[Bytes, Withdrawal]]], + ], + ("ethereum", "cancun", "trie", "TrieBytes32U256"): Trie[Bytes32, U256], **vm_exception_mappings, **ethereum_exception_mappings, # For tests only @@ -664,6 +675,10 @@ def isinstance_with_generic(obj, type_hint): if origin is abc.Sequence: return type(obj) in (list, tuple) + if origin is Trie: + key_type, value_type = obj.__orig_class__.__args__ + return origin[key_type, value_type] == type_hint + return isinstance(obj, origin) diff --git a/python/cairo-addons/src/cairo_addons/hints/hashdict.py b/python/cairo-addons/src/cairo_addons/hints/hashdict.py index 1c716c447..c6287b15a 100644 --- a/python/cairo-addons/src/cairo_addons/hints/hashdict.py +++ b/python/cairo-addons/src/cairo_addons/hints/hashdict.py @@ -77,7 +77,7 @@ def get_preimage_for_key( ): from cairo_addons.hints.hashdict import _get_preimage_for_hashed_key - preimage = bytes( + preimage = list( _get_preimage_for_hashed_key( ids.key, dict_manager.get_tracker(ids.dict_ptr_stop) )