Skip to content

Commit

Permalink
modify test_serde to remove None in dicts and keep dict types
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Feb 5, 2025
1 parent a0647c1 commit 2d1b4c7
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
16 changes: 14 additions & 2 deletions cairo/tests/test_serde.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import ChainMap
from typing import Annotated, Any, List, Mapping, Optional, Set, Tuple, Type, Union

import pytest
Expand Down Expand Up @@ -88,8 +89,7 @@ def get_type(instance: Any) -> Type:
if isinstance(instance, Mapping):
# Get key and value types from the first item in the mapping
if instance:
key_type = get_type(next(iter(instance.keys())))
value_type = get_type(next(iter(instance.values())))
key_type, value_type = instance.__orig_class__.__args__
return Mapping[key_type, value_type]
return Mapping

Expand Down Expand Up @@ -173,6 +173,17 @@ def single_evm_parent(b: Union[Message, Evm]) -> bool:
return True


def remove_none_values(b: Any) -> Any:
"""Recursively remove None values from mappings and their nested structures."""
if isinstance(b, (dict, ChainMap, Mapping)):
return {k: remove_none_values(v) for k, v in b.items() if v is not None}
elif isinstance(b, (list, tuple)):
return type(b)(remove_none_values(x) for x in b)
elif isinstance(b, set):
return {remove_none_values(x) for x in b if x is not None}
return b


class TestSerde:
@given(b=...)
# 20 examples per type
Expand Down Expand Up @@ -283,6 +294,7 @@ def test_type(
assume(no_empty_sequence(b))
assume(single_evm_parent(b))
type_ = get_type(b)
b = remove_none_values(b)
base = segments.gen_arg([gen_arg(type_, b)])
result = serde.serialize(to_cairo_type(type_), base, shift=0)
assert result == b
Expand Down
21 changes: 21 additions & 0 deletions cairo/tests/utils/strategies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# ruff: noqa: E402

import os
from collections import ChainMap
from typing import (
ForwardRef,
Generic,
Expand Down Expand Up @@ -210,6 +211,24 @@ def tuple_strategy(thing):
)


K = TypeVar("K")
V = TypeVar("V")


class TypedDict(dict, Generic[K, V]):
"""A dict that maintains its type information."""

def __new__(cls, values):
return super(TypedDict, cls).__new__(cls, values)


def dict_strategy(thing):
key_type, value_type = thing.__args__
return st.dictionaries(st.from_type(key_type), st.from_type(value_type)).map(
lambda x: TypedDict[key_type, value_type](x)
)


gas_left = st.integers(min_value=0, max_value=BLOCK_GAS_LIMIT).map(Uint)

accessed_addresses = st.sets(st.from_type(Address), max_size=MAX_ADDRESS_SET_SIZE)
Expand Down Expand Up @@ -526,6 +545,8 @@ def register_type_strategies():
st.register_type_strategy(Memory, memory)
st.register_type_strategy(Evm, evm)
st.register_type_strategy(tuple, tuple_strategy)
st.register_type_strategy(dict, dict_strategy)
st.register_type_strategy(ChainMap, dict_strategy)
st.register_type_strategy(State, state)
st.register_type_strategy(TransientStorage, transient_storage)
st.register_type_strategy(MutableBloom, bloom.map(MutableBloom))
Expand Down

0 comments on commit 2d1b4c7

Please sign in to comment.