diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py new file mode 100644 index 0000000..667dc27 --- /dev/null +++ b/remerkleable/stable_container.py @@ -0,0 +1,710 @@ +# This file implements `StableContainer` according to https://eips.ethereum.org/EIPS/eip-7495 +# The EIP is still under review, functionality may change or go away without deprecation. + +import io +from typing import Any, BinaryIO, Dict, List as PyList, Optional, Tuple, \ + TypeVar, Type, Union as PyUnion, \ + get_args, get_origin +from textwrap import indent +from remerkleable.basic import boolean, uint8, uint16, uint32, uint64, uint128, uint256 +from remerkleable.bitfields import Bitlist, Bitvector +from remerkleable.byte_arrays import ByteList, ByteVector +from remerkleable.complex import ComplexView, Container, FieldOffset, List, Vector, \ + decode_offset, encode_offset +from remerkleable.core import View, ViewHook, ViewMeta, OFFSET_BYTE_LENGTH +from remerkleable.tree import Gindex, NavigationError, Node, PairNode, \ + get_depth, subtree_fill_to_contents, zero_node, \ + RIGHT_GINDEX + +N = TypeVar('N', bound=int) +SV = TypeVar('SV', bound='ComplexView') +BV = TypeVar('BV', bound='ComplexView') + + +def stable_get(self, findex, ftyp, n): + if not self.active_fields().get(findex): + return None + data = self.get_backing().get_left() + fnode = data.getter(2**get_depth(n) + findex) + return ftyp.view_from_backing(fnode) + + +def stable_set(self, findex, ftyp, n, value): + next_backing = self.get_backing() + + active_fields = self.active_fields() + active_fields.set(findex, value is not None) + next_backing = next_backing.rebind_right(active_fields.get_backing()) + + if value is not None: + if isinstance(value, ftyp): + fnode = value.get_backing() + else: + fnode = ftyp.coerce_view(value).get_backing() + else: + fnode = zero_node(0) + data = next_backing.get_left() + next_data = data.setter(2**get_depth(n) + findex)(fnode) + next_backing = next_backing.rebind_left(next_data) + + self.set_backing(next_backing) + + +def field_val_repr(self, fkey: str, ftyp: Type[View], fopt: bool) -> str: + field_start = ' ' + fkey + ': ' + ( + ('Optional[' if fopt else '') + ftyp.__name__ + (']' if fopt else '') + ) + ' = ' + try: + field_repr = getattr(self, fkey).__repr__() + if '\n' in field_repr: # if multiline, indent it, but starting from the value. + i = field_repr.index('\n') + field_repr = field_repr[:i+1] + indent(field_repr[i+1:], ' ' * len(field_start)) + return field_start + field_repr + except NavigationError: + return f'{field_start} *omitted*' + + +class StableContainer(ComplexView): + __slots__ = '_field_indices', 'N' + _field_indices: Dict[str, Tuple[int, Type[View]]] + N: int + + def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None, **kwargs): + if backing is not None: + if len(kwargs) != 0: + raise Exception('Cannot have both a backing and elements to init fields') + return super().__new__(cls, backing=backing, hook=hook, **kwargs) + + input_nodes = [] + active_fields = Bitvector[cls.N]() + for fkey, (findex, ftyp) in cls._field_indices.items(): + fnode: Node + finput = kwargs.pop(fkey) if fkey in kwargs else None + if finput is None: + fnode = zero_node(0) + active_fields.set(findex, False) + else: + if isinstance(finput, View): + fnode = finput.get_backing() + else: + fnode = ftyp.coerce_view(finput).get_backing() + active_fields.set(findex, True) + input_nodes.append(fnode) + if len(kwargs) > 0: + raise AttributeError(f'Fields [{"".join(kwargs.keys())}] unknown in `{cls.__name__}`') + + backing = PairNode( + left=subtree_fill_to_contents(input_nodes, get_depth(cls.N)), + right=active_fields.get_backing(), + ) + return super().__new__(cls, backing=backing, hook=hook, **kwargs) + + def __init_subclass__(cls, **kwargs): + if 'n' not in kwargs: + raise TypeError(f'Missing capacity: `{cls.__name__}(StableContainer)`') + n = kwargs.pop('n') + if not isinstance(n, int): + raise TypeError(f'Invalid capacity: `StableContainer[{n}]`') + if n <= 0: + raise TypeError(f'Unsupported capacity: `StableContainer[{n}]`') + cls.N = n + + def __class_getitem__(cls, n: int) -> Type['StableContainer']: + class StableContainerMeta(ViewMeta): + def __new__(cls, name, bases, dct): + return super().__new__(cls, name, bases, dct, n=n) + + class StableContainerView(StableContainer, metaclass=StableContainerMeta): + def __init_subclass__(cls, **kwargs): + if 'N' in cls.__dict__: + raise TypeError(f'Cannot override `N` inside `{cls.__name__}`') + cls._field_indices = {} + for findex, (fkey, t) in enumerate(cls.__annotations__.items()): + if ( + get_origin(t) != PyUnion + or len(get_args(t)) != 2 + or type(None) not in get_args(t) + ): + raise TypeError( + f'`StableContainer` fields must be `Optional[T]` ' + f'but `{cls.__name__}.{fkey}` has type `{t.__name__}`' + ) + ftyp = get_args(t)[0] if get_args(t)[0] is not type(None) else get_args(t)[1] + cls._field_indices[fkey] = (findex, ftyp) + if len(cls._field_indices) > cls.N: + raise TypeError( + f'`{cls.__name__}` is `StableContainer[{cls.N}]` ' + f'but contains {len(cls._field_indices)} fields' + ) + + StableContainerView.__name__ = StableContainerView.type_repr() + return StableContainerView + + @classmethod + def fields(cls) -> Dict[str, Type[View]]: + return { fkey: ftyp for fkey, (_, ftyp) in cls._field_indices.items() } + + @classmethod + def is_fixed_byte_length(cls) -> bool: + return False + + @classmethod + def min_byte_length(cls) -> int: + return Bitvector[cls.N].type_byte_length() + + @classmethod + def max_byte_length(cls) -> int: + total = Bitvector[cls.N].type_byte_length() + for (_, ftyp) in cls._field_indices.values(): + if not ftyp.is_fixed_byte_length(): + total += OFFSET_BYTE_LENGTH + total += ftyp.max_byte_length() + return total + + @classmethod + def is_packed(cls) -> bool: + return False + + @classmethod + def tree_depth(cls) -> int: + return get_depth(cls.N) + + @classmethod + def item_elem_cls(cls, i: int) -> Type[View]: + return list(cls._field_indices.values())[i] + + def active_fields(self) -> Bitvector: + active_fields_node = super().get_backing().get_right() + return Bitvector[self.__class__.N].view_from_backing(active_fields_node) + + def __getattribute__(self, item): + if item == 'N': + raise AttributeError(f'Use `.__class__.{item}` to access `{item}`') + return object.__getattribute__(self, item) + + def __getattr__(self, item): + if item[0] == '_': + return super().__getattribute__(item) + else: + try: + (findex, ftyp) = self.__class__._field_indices[item] + except KeyError: + raise AttributeError(f'Unknown field `{item}`') + + return stable_get(self, findex, ftyp, self.__class__.N) + + def __setattr__(self, key, value): + if key[0] == '_': + super().__setattr__(key, value) + else: + try: + (findex, ftyp) = self.__class__._field_indices[key] + except KeyError: + raise AttributeError(f'Unknown field `{key}`') + + stable_set(self, findex, ftyp, self.__class__.N, value) + + def __repr__(self): + return f'{self.__class__.type_repr()}:\n' + '\n'.join( + indent(field_val_repr(self, fkey, ftyp, fopt=True), ' ') + for fkey, (_, ftyp) in self.__class__._field_indices.items()) + + @classmethod + def type_repr(cls) -> str: + return f'StableContainer[{cls.N}]' + + @classmethod + def deserialize(cls: Type[SV], stream: BinaryIO, scope: int) -> SV: + num_prefix_bytes = Bitvector[cls.N].type_byte_length() + if scope < num_prefix_bytes: + raise ValueError(f'Scope too small for `StableContainer[{cls.N}]` active fields') + active_fields = Bitvector[cls.N].deserialize(stream, num_prefix_bytes) + scope = scope - num_prefix_bytes + + for findex in range(len(cls._field_indices), cls.N): + if active_fields.get(findex): + raise Exception(f'Unknown field index {findex}') + + field_values: Dict[str, View] = {} + dyn_fields: PyList[FieldOffset] = [] + fixed_size = 0 + for fkey, (findex, ftyp) in cls._field_indices.items(): + if not active_fields.get(findex): + continue + if ftyp.is_fixed_byte_length(): + fsize = ftyp.type_byte_length() + field_values[fkey] = ftyp.deserialize(stream, fsize) + fixed_size += fsize + else: + dyn_fields.append(FieldOffset( + key=fkey, typ=ftyp, offset=int(decode_offset(stream)))) + fixed_size += OFFSET_BYTE_LENGTH + if len(dyn_fields) > 0: + if dyn_fields[0].offset < fixed_size: + raise Exception(f'First offset {dyn_fields[0].offset} is ' + f'smaller than expected fixed size {fixed_size}') + for i, (fkey, ftyp, foffset) in enumerate(dyn_fields): + next_offset = dyn_fields[i + 1].offset if i + 1 < len(dyn_fields) else scope + if foffset > next_offset: + raise Exception(f'Offset {i} is invalid: {foffset} ' + f'larger than next offset {next_offset}') + fsize = next_offset - foffset + f_min_size, f_max_size = ftyp.min_byte_length(), ftyp.max_byte_length() + if not (f_min_size <= fsize <= f_max_size): + raise Exception(f'Offset {i} is invalid, size out of bounds: ' + f'{foffset}, next {next_offset}, implied size: {fsize}, ' + f'size bounds: [{f_min_size}, {f_max_size}]') + field_values[fkey] = ftyp.deserialize(stream, fsize) + return cls(**field_values) # type: ignore + + def serialize(self, stream: BinaryIO) -> int: + active_fields = self.active_fields() + num_prefix_bytes = active_fields.serialize(stream) + + num_data_bytes = 0 + has_dyn_fields = False + for (findex, ftyp) in self.__class__._field_indices.values(): + if not active_fields.get(findex): + continue + if ftyp.is_fixed_byte_length(): + num_data_bytes += ftyp.type_byte_length() + else: + num_data_bytes += OFFSET_BYTE_LENGTH + has_dyn_fields = True + + if has_dyn_fields: + temp_dyn_stream = io.BytesIO() + data = super().get_backing().get_left() + for (findex, ftyp) in self.__class__._field_indices.values(): + if not active_fields.get(findex): + continue + fnode = data.getter(2**get_depth(self.__class__.N) + findex) + v = ftyp.view_from_backing(fnode) + if ftyp.is_fixed_byte_length(): + v.serialize(stream) + else: + encode_offset(stream, num_data_bytes) + num_data_bytes += v.serialize(temp_dyn_stream) # type: ignore + if has_dyn_fields: + temp_dyn_stream.seek(0) + stream.write(temp_dyn_stream.read()) + + return num_prefix_bytes + num_data_bytes + + @classmethod + def navigate_type(cls, key: Any) -> Type[View]: + if key == '__active_fields__': + return Bitvector[cls.N] + (_, ftyp) = cls._field_indices[key] + return Optional[ftyp] + + @classmethod + def key_to_static_gindex(cls, key: Any) -> Gindex: + if key == '__active_fields__': + return RIGHT_GINDEX + (findex, _) = cls._field_indices[key] + return 2**get_depth(cls.N) * 2 + findex + + +class Profile(ComplexView): + __slots__ = '_field_indices', '_o', 'B' + _field_indices: Dict[str, Tuple[int, Type[View], bool]] + _o: int + B: Type[StableContainer] + + def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None, **kwargs): + if backing is not None: + if len(kwargs) != 0: + raise Exception('Cannot have both a backing and elements to init fields') + return super().__new__(cls, backing=backing, hook=hook, **kwargs) + + extra_kw = kwargs.copy() + for fkey, (_, _, fopt) in cls._field_indices.items(): + if fkey in extra_kw: + extra_kw.pop(fkey) + elif not fopt: + raise AttributeError(f'Field `{fkey}` is required in {cls.__name__}') + else: + pass + if len(extra_kw) > 0: + raise AttributeError(f'Fields [{"".join(extra_kw.keys())}] unknown in `{cls.__name__}`') + + value = cls.B(backing, hook, **kwargs) + return cls(backing=value.get_backing()) + + def __init_subclass__(cls, **kwargs): + if 'b' not in kwargs: + raise TypeError(f'Missing base type: `{cls.__name__}(Profile)`') + b = kwargs.pop('b') + if not issubclass(b, StableContainer): + raise TypeError(f'Invalid base type: `Profile[{b.__name__}]`') + cls.B = b + + def __class_getitem__(cls, b) -> Type['Profile']: + def has_compatible_merkleization(ftyp, ftyp_base) -> bool: + if ftyp == ftyp_base: + return True + if issubclass(ftyp, boolean): + return issubclass(ftyp_base, boolean) + if issubclass(ftyp, uint8): + return issubclass(ftyp_base, uint8) + if issubclass(ftyp, uint16): + return issubclass(ftyp_base, uint16) + if issubclass(ftyp, uint32): + return issubclass(ftyp_base, uint32) + if issubclass(ftyp, uint64): + return issubclass(ftyp_base, uint64) + if issubclass(ftyp, uint128): + return issubclass(ftyp_base, uint128) + if issubclass(ftyp, uint256): + return issubclass(ftyp_base, uint256) + if issubclass(ftyp, Bitlist): + return ( + issubclass(ftyp_base, Bitlist) + and ftyp.limit() == ftyp_base.limit() + ) + if issubclass(ftyp, Bitvector): + return ( + issubclass(ftyp_base, Bitvector) + and ftyp.vector_length() == ftyp_base.vector_length() + ) + if issubclass(ftyp, ByteList): + if issubclass(ftyp_base, ByteList): + return ftyp.limit() == ftyp_base.limit() + return ( + issubclass(ftyp_base, List) + and ftyp.limit() == ftyp_base.limit() + and issubclass(ftyp_base.element_cls(), uint8) + ) + if issubclass(ftyp, ByteVector): + if issubclass(ftyp_base, ByteVector): + return ftyp.vector_length() == ftyp_base.vector_length() + return ( + issubclass(ftyp_base, Vector) + and ftyp.vector_length() == ftyp_base.vector_length() + and issubclass(ftyp_base.element_cls(), uint8) + ) + if issubclass(ftyp, List): + if issubclass(ftyp_base, ByteList): + return ( + ftyp.limit() == ftyp_base.limit() + and issubclass(ftyp.element_cls(), uint8) + ) + return ( + issubclass(ftyp_base, List) + and ftyp.limit() == ftyp_base.limit() + and has_compatible_merkleization(ftyp.element_cls(), ftyp_base.element_cls()) + ) + if issubclass(ftyp, Vector): + if issubclass(ftyp_base, ByteVector): + return ( + ftyp.vector_length() == ftyp_base.vector_length() + and issubclass(ftyp.element_cls(), uint8) + ) + return ( + issubclass(ftyp_base, Vector) + and ftyp.vector_length() == ftyp_base.vector_length() + and has_compatible_merkleization(ftyp.element_cls(), ftyp_base.element_cls()) + ) + if issubclass(ftyp, Container): + if not issubclass(ftyp_base, Container): + return False + fields = ftyp.fields() + fields_base = ftyp_base.fields() + if len(fields) != len(fields_base): + return False + for (fkey, t), (fkey_b, t_b) in zip(fields.items(), fields_base.items()): + if fkey != fkey_b: + return False + if not has_compatible_merkleization(t, t_b): + return False + return True + if issubclass(ftyp, StableContainer): + if not issubclass(ftyp_base, StableContainer): + return False + if ftyp.N != ftyp_base.N: + return False + fields = ftyp.fields() + fields_base = ftyp_base.fields() + if len(fields) != len(fields_base): + return False + for (fkey, t), (fkey_b, t_b) in zip(fields.items(), fields_base.items()): + if fkey != fkey_b: + return False + if not has_compatible_merkleization(t, t_b): + return False + return True + if issubclass(ftyp, Profile): + if issubclass(ftyp_base, StableContainer): + return has_compatible_merkleization(ftyp.B, ftyp_base) + if not issubclass(ftyp_base, Profile): + return False + if not has_compatible_merkleization(ftyp.B, ftyp_base.B): + return False + fields = ftyp.fields() + fields_base = ftyp_base.fields() + if len(fields) != len(fields_base): + return False + for (fkey, (t, _)), (fkey_b, (t_b, _)) in zip(fields.items(), fields_base.items()): + if fkey != fkey_b: + return False + if not has_compatible_merkleization(t, t_b): + return False + return True + return False + + class ProfileMeta(ViewMeta): + def __new__(cls, name, bases, dct): + return super().__new__(cls, name, bases, dct, b=b) + + class ProfileView(Profile, metaclass=ProfileMeta): + def __init_subclass__(cls, **kwargs): + if 'B' in cls.__dict__: + raise TypeError(f'Cannot override `B` inside `{cls.__name__}`') + cls._field_indices = {} + cls._o = 0 + last_findex = -1 + for (fkey, t) in cls.__annotations__.items(): + if fkey not in cls.B._field_indices: + raise TypeError( + f'`{cls.__name__}` fields must exist in the base type ' + f'but `{fkey}` is not defined in `{cls.B.__name__}`' + ) + (findex, ftyp) = cls.B._field_indices[fkey] + if findex <= last_findex: + raise TypeError( + f'`{cls.__name__}` fields must have the same order as in the base type ' + f'but `{fkey}` is defined earlier than in `{cls.B.__name__}`' + ) + last_findex = findex + fopt = ( + get_origin(t) == PyUnion + and len(get_args(t)) == 2 + and type(None) in get_args(t) + ) + if fopt: + t = get_args(t)[0] if get_args(t)[0] is not type(None) else get_args(t)[1] + if not has_compatible_merkleization(t, ftyp): + raise TypeError( + f'`{cls.__name__}.{fkey}` has type `{t.__name__}`, incompatible ' + f'with base field `{cls.B.__name__}.{fkey}` of type `{ftyp.__name__}`' + ) + cls._field_indices[fkey] = (findex, t, fopt) + if fopt: + cls._o += 1 + + ProfileView.__name__ = ProfileView.type_repr() + return ProfileView + + @classmethod + def fields(cls) -> Dict[str, Tuple[Type[View], bool]]: + return { fkey: (ftyp, fopt) for fkey, (_, ftyp, fopt) in cls._field_indices.items() } + + @classmethod + def is_fixed_byte_length(cls) -> bool: + if cls._o > 0: + return False + for (_, ftyp, _) in cls._field_indices.values(): + if not ftyp.is_fixed_byte_length(): + return False + return True + + @classmethod + def type_byte_length(cls) -> int: + if cls.is_fixed_byte_length(): + return cls.min_byte_length() + else: + raise Exception(f'Dynamic length `Profile` does not have a fixed byte length') + + @classmethod + def min_byte_length(cls) -> int: + total = Bitvector[cls._o].type_byte_length() if cls._o > 0 else 0 + for (_, ftyp, fopt) in cls._field_indices.values(): + if fopt: + continue + if not ftyp.is_fixed_byte_length(): + total += OFFSET_BYTE_LENGTH + total += ftyp.min_byte_length() + return total + + @classmethod + def max_byte_length(cls) -> int: + total = Bitvector[cls._o].type_byte_length() if cls._o > 0 else 0 + for (_, ftyp, _) in cls._field_indices.values(): + if not ftyp.is_fixed_byte_length(): + total += OFFSET_BYTE_LENGTH + total += ftyp.max_byte_length() + return total + + @classmethod + def is_packed(cls) -> bool: + return False + + @classmethod + def tree_depth(cls) -> int: + return cls.B.tree_depth() + + @classmethod + def item_elem_cls(cls, i: int) -> Type[View]: + return cls.B.item_elem_cls(i) + + def active_fields(self) -> Bitvector: + active_fields_node = super().get_backing().get_right() + return Bitvector[self.__class__.B.N].view_from_backing(active_fields_node) + + def optional_fields(self) -> Bitvector: + if self.__class__._o == 0: + raise Exception(f'`{self.__class__.__name__}` does not have any `Optional[T]` fields') + active_fields = self.active_fields() + optional_fields = Bitvector[self.__class__._o]() + oindex = 0 + for (findex, _, fopt) in self.__class__._field_indices.values(): + if fopt: + optional_fields.set(oindex, active_fields.get(findex)) + oindex += 1 + return optional_fields + + def __getattribute__(self, item): + if item == 'B': + raise AttributeError(f'Use `.__class__.{item}` to access `{item}`') + return object.__getattribute__(self, item) + + def __getattr__(self, item): + if item[0] == '_': + return super().__getattribute__(item) + else: + try: + (findex, ftyp, fopt) = self.__class__._field_indices[item] + except KeyError: + raise AttributeError(f'Unknown field `{item}`') + + value = stable_get(self, findex, ftyp, self.__class__.B.N) + assert value is not None or fopt + return value + + def __setattr__(self, key, value): + if key[0] == '_': + super().__setattr__(key, value) + else: + try: + (findex, ftyp, fopt) = self.__class__._field_indices[key] + except KeyError: + raise AttributeError(f'Unknown field `{key}`') + + if value is None and not fopt: + raise ValueError(f'Field `{key}` is required and cannot be set to `None`') + stable_set(self, findex, ftyp, self.__class__.B.N, value) + + def __repr__(self): + return f'{self.__class__.type_repr()}:\n' + '\n'.join( + indent(field_val_repr(self, fkey, ftyp, fopt), ' ') + for fkey, (_, ftyp, fopt) in self.__class__._field_indices.items()) + + @classmethod + def type_repr(cls) -> str: + return f'Profile[{cls.B.__name__}]' + + @classmethod + def deserialize(cls: Type[BV], stream: BinaryIO, scope: int) -> BV: + if cls._o > 0: + num_prefix_bytes = Bitvector[cls._o].type_byte_length() + if scope < num_prefix_bytes: + raise ValueError(f'Scope too small for `Profile[{cls.B.__name__}]` optional fields') + optional_fields = Bitvector[cls._o].deserialize(stream, num_prefix_bytes) + scope = scope - num_prefix_bytes + + field_values: Dict[str, Optional[View]] = {} + dyn_fields: PyList[FieldOffset] = [] + fixed_size = 0 + oindex = 0 + for fkey, (_, ftyp, fopt) in cls._field_indices.items(): + if fopt: + has_field = optional_fields.get(oindex) + oindex += 1 + if not has_field: + field_values[fkey] = None + continue + if ftyp.is_fixed_byte_length(): + fsize = ftyp.type_byte_length() + field_values[fkey] = ftyp.deserialize(stream, fsize) + fixed_size += fsize + else: + dyn_fields.append(FieldOffset( + key=fkey, typ=ftyp, offset=int(decode_offset(stream)))) + fixed_size += OFFSET_BYTE_LENGTH + assert oindex == cls._o + if len(dyn_fields) > 0: + if dyn_fields[0].offset < fixed_size: + raise Exception(f'First offset {dyn_fields[0].offset} is ' + f'smaller than expected fixed size {fixed_size}') + for i, (fkey, ftyp, foffset) in enumerate(dyn_fields): + next_offset = dyn_fields[i + 1].offset if i + 1 < len(dyn_fields) else scope + if foffset > next_offset: + raise Exception(f'Offset {i} is invalid: {foffset} ' + f'larger than next offset {next_offset}') + fsize = next_offset - foffset + f_min_size, f_max_size = ftyp.min_byte_length(), ftyp.max_byte_length() + if not (f_min_size <= fsize <= f_max_size): + raise Exception(f'Offset {i} is invalid, size out of bounds: ' + f'{foffset}, next {next_offset}, implied size: {fsize}, ' + f'size bounds: [{f_min_size}, {f_max_size}]') + field_values[fkey] = ftyp.deserialize(stream, fsize) + return cls(**field_values) # type: ignore + + def serialize(self, stream: BinaryIO) -> int: + if self.__class__._o > 0: + optional_fields = self.optional_fields() + num_prefix_bytes = optional_fields.serialize(stream) + else: + num_prefix_bytes = 0 + + num_data_bytes = 0 + has_dyn_fields = False + oindex = 0 + for (_, ftyp, fopt) in self.__class__._field_indices.values(): + if fopt: + has_field = optional_fields.get(oindex) + oindex += 1 + if not has_field: + continue + if ftyp.is_fixed_byte_length(): + num_data_bytes += ftyp.type_byte_length() + else: + num_data_bytes += OFFSET_BYTE_LENGTH + has_dyn_fields = True + assert oindex == self.__class__._o + + if has_dyn_fields: + temp_dyn_stream = io.BytesIO() + data = super().get_backing().get_left() + active_fields = self.active_fields() + n = self.__class__.B.N + for (findex, ftyp, _) in self.__class__._field_indices.values(): + if not active_fields.get(findex): + continue + fnode = data.getter(2**get_depth(n) + findex) + v = ftyp.view_from_backing(fnode) + if ftyp.is_fixed_byte_length(): + v.serialize(stream) + else: + encode_offset(stream, num_data_bytes) + num_data_bytes += v.serialize(temp_dyn_stream) # type: ignore + if has_dyn_fields: + temp_dyn_stream.seek(0) + stream.write(temp_dyn_stream.read(num_data_bytes)) + + return num_prefix_bytes + num_data_bytes + + @classmethod + def navigate_type(cls, key: Any) -> Type[View]: + if key == '__active_fields__': + return Bitvector[cls.B.N] + (_, ftyp, fopt) = cls._field_indices[key] + return Optional[ftyp] if fopt else ftyp + + @classmethod + def key_to_static_gindex(cls, key: Any) -> Gindex: + if key == '__active_fields__': + return RIGHT_GINDEX + (findex, _, _) = cls._field_indices[key] + return 2**get_depth(cls.B.N) * 2 + findex diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index 280b446..966bfb2 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -2,13 +2,14 @@ # flake8:noqa E501 Ignore long lines, some test cases are just inherently long -from typing import Iterable, Type +from typing import Iterable, Optional, Type import io from remerkleable.complex import Container, Vector, List from remerkleable.basic import boolean, bit, byte, uint8, uint16, uint32, uint64, uint128, uint256 from remerkleable.bitfields import Bitvector, Bitlist from remerkleable.byte_arrays import ByteVector, ByteList from remerkleable.core import View, ObjType +from remerkleable.stable_container import Profile, StableContainer from remerkleable.union import Union from hashlib import sha256 @@ -475,3 +476,402 @@ class B(Container): assert A(1, 2, 3) != B(1, 2, 3, 0) assert A(1, 2, 3) in {A(1, 2, 3)} assert A(1, 2, 3) not in {B(1, 2, 3, 0)} + + +def test_stable_container(): + class Shape(StableContainer[4]): + side: Optional[uint16] + color: Optional[uint8] + radius: Optional[uint16] + + class Square(Profile[Shape]): + side: uint16 + color: uint8 + + class Circle(Profile[Shape]): + color: uint8 + radius: uint16 + + class ShapePair(Container): + shape_1: Shape + shape_2: Shape + + class SquarePair(Container): + shape_1: Square + shape_2: Square + + class CirclePair(Container): + shape_1: Circle + shape_2: Circle + + # Helper containers for merkleization testing + class ShapePayload(Container): + side: uint16 + color: uint8 + radius: uint16 + + class ShapeRepr(Container): + value: ShapePayload + active_fields: Bitvector[4] + + class ShapePairRepr(Container): + shape_1: ShapeRepr + shape_2: ShapeRepr + + # Square tests + square_bytes_stable = bytes.fromhex("03420001") + square_bytes_profile = bytes.fromhex("420001") + square_root = ShapeRepr( + value=ShapePayload(side=0x42, color=1, radius=0), + active_fields=Bitvector[4](True, True, False, False), + ).hash_tree_root() + shapes = [Shape(side=0x42, color=1, radius=None)] + squares = [Square(side=0x42, color=1)] + squares.extend(list(Square(backing=shape.get_backing()) for shape in shapes)) + shapes.extend(list(Shape(backing=shape.get_backing()) for shape in shapes)) + shapes.extend(list(Shape(backing=square.get_backing()) for square in squares)) + squares.extend(list(Square(backing=square.get_backing()) for square in squares)) + assert len(set(shapes)) == 1 + assert len(set(squares)) == 1 + assert all(shape.encode_bytes() == square_bytes_stable for shape in shapes) + assert all(square.encode_bytes() == square_bytes_profile for square in squares) + assert ( + Square(backing=Shape.decode_bytes(square_bytes_stable).get_backing()) == + Square.decode_bytes(square_bytes_profile) + ) + assert all(shape.hash_tree_root() == square_root for shape in shapes) + assert all(square.hash_tree_root() == square_root for square in squares) + try: + circle = Circle(side=0x42, color=1) + assert False + except: + pass + for shape in shapes: + try: + circle = Circle(backing=shape.get_backing()) + assert False + except: + pass + for square in squares: + try: + circle = Circle(backing=square.get_backing()) + assert False + except: + pass + for shape in shapes: + shape.side = 0x1337 + for square in squares: + square.side = 0x1337 + square_bytes_stable = bytes.fromhex("03371301") + square_bytes_profile = bytes.fromhex("371301") + square_root = ShapeRepr( + value=ShapePayload(side=0x1337, color=1, radius=0), + active_fields=Bitvector[4](True, True, False, False), + ).hash_tree_root() + assert len(set(shapes)) == 1 + assert len(set(squares)) == 1 + assert all(shape.encode_bytes() == square_bytes_stable for shape in shapes) + assert all(square.encode_bytes() == square_bytes_profile for square in squares) + assert ( + Square(backing=Shape.decode_bytes(square_bytes_stable).get_backing()) == + Square.decode_bytes(square_bytes_profile) + ) + assert all(shape.hash_tree_root() == square_root for shape in shapes) + assert all(square.hash_tree_root() == square_root for square in squares) + for square in squares: + try: + square.radius = 0x1337 + assert False + except: + pass + for square in squares: + try: + square.side = None + assert False + except: + pass + + # Circle tests + circle_bytes_stable = bytes.fromhex("06014200") + circle_bytes_profile = bytes.fromhex("014200") + circle_root = ShapeRepr( + value=ShapePayload(side=0, color=1, radius=0x42), + active_fields=Bitvector[4](False, True, True, False), + ).hash_tree_root() + modified_shape = shapes[0] + modified_shape.side = None + modified_shape.radius = 0x42 + shapes = [Shape(side=None, color=1, radius=0x42), modified_shape] + circles = [Circle(radius=0x42, color=1)] + circles.extend(list(Circle(backing=shape.get_backing()) for shape in shapes)) + shapes.extend(list(Shape(backing=shape.get_backing()) for shape in shapes)) + shapes.extend(list(Shape(backing=circle.get_backing()) for circle in circles)) + circles.extend(list(Circle(backing=circle.get_backing()) for circle in circles)) + assert len(set(shapes)) == 1 + assert len(set(circles)) == 1 + assert all(shape.encode_bytes() == circle_bytes_stable for shape in shapes) + assert all(circle.encode_bytes() == circle_bytes_profile for circle in circles) + assert ( + Circle(backing=Shape.decode_bytes(circle_bytes_stable).get_backing()) == + Circle.decode_bytes(circle_bytes_profile) + ) + assert all(shape.hash_tree_root() == circle_root for shape in shapes) + assert all(circle.hash_tree_root() == circle_root for circle in circles) + try: + square = Square(radius=0x42, color=1) + assert False + except: + pass + for shape in shapes: + try: + square = Square(backing=shape.get_backing()) + assert False + except: + pass + for circle in circles: + try: + square = Square(backing=circle.get_backing()) + assert False + except: + pass + + # SquarePair tests + square_pair_bytes_stable = bytes.fromhex("080000000c0000000342000103690001") + square_pair_bytes_profile = bytes.fromhex("420001690001") + square_pair_root = ShapePairRepr( + shape_1=ShapeRepr( + value=ShapePayload(side=0x42, color=1, radius=0), + active_fields=Bitvector[4](True, True, False, False), + ), + shape_2=ShapeRepr( + value=ShapePayload(side=0x69, color=1, radius=0), + active_fields=Bitvector[4](True, True, False, False), + ) + ).hash_tree_root() + shape_pairs = [ShapePair( + shape_1=Shape(side=0x42, color=1, radius=None), + shape_2=Shape(side=0x69, color=1, radius=None), + )] + square_pairs = [SquarePair( + shape_1=Square(side=0x42, color=1), + shape_2=Square(side=0x69, color=1), + )] + square_pairs.extend(list(SquarePair(backing=pair.get_backing()) for pair in shape_pairs)) + shape_pairs.extend(list(ShapePair(backing=pair.get_backing()) for pair in shape_pairs)) + shape_pairs.extend(list(ShapePair(backing=pair.get_backing()) for pair in square_pairs)) + square_pairs.extend(list(SquarePair(backing=pair.get_backing()) for pair in square_pairs)) + assert len(set(shape_pairs)) == 1 + assert len(set(square_pairs)) == 1 + assert all(pair.encode_bytes() == square_pair_bytes_stable for pair in shape_pairs) + assert all(pair.encode_bytes() == square_pair_bytes_profile for pair in square_pairs) + assert ( + SquarePair(backing=ShapePair.decode_bytes(square_pair_bytes_stable).get_backing()) == + SquarePair.decode_bytes(square_pair_bytes_profile) + ) + assert all(pair.hash_tree_root() == square_pair_root for pair in shape_pairs) + assert all(pair.hash_tree_root() == square_pair_root for pair in square_pairs) + + # CirclePair tests + circle_pair_bytes_stable = bytes.fromhex("080000000c0000000601420006016900") + circle_pair_bytes_profile = bytes.fromhex("014200016900") + circle_pair_root = ShapePairRepr( + shape_1=ShapeRepr( + value=ShapePayload(side=0, color=1, radius=0x42), + active_fields=Bitvector[4](False, True, True, False), + ), + shape_2=ShapeRepr( + value=ShapePayload(side=0, color=1, radius=0x69), + active_fields=Bitvector[4](False, True, True, False), + ) + ).hash_tree_root() + shape_pairs = [ShapePair( + shape_1=Shape(side=None, color=1, radius=0x42), + shape_2=Shape(side=None, color=1, radius=0x69), + )] + circle_pairs = [CirclePair( + shape_1=Circle(radius=0x42, color=1), + shape_2=Circle(radius=0x69, color=1), + )] + circle_pairs.extend(list(CirclePair(backing=pair.get_backing()) for pair in shape_pairs)) + shape_pairs.extend(list(ShapePair(backing=pair.get_backing()) for pair in shape_pairs)) + shape_pairs.extend(list(ShapePair(backing=pair.get_backing()) for pair in circle_pairs)) + circle_pairs.extend(list(CirclePair(backing=pair.get_backing()) for pair in circle_pairs)) + assert len(set(shape_pairs)) == 1 + assert len(set(circle_pairs)) == 1 + assert all(pair.encode_bytes() == circle_pair_bytes_stable for pair in shape_pairs) + assert all(pair.encode_bytes() == circle_pair_bytes_profile for pair in circle_pairs) + assert ( + CirclePair(backing=ShapePair.decode_bytes(circle_pair_bytes_stable).get_backing()) == + CirclePair.decode_bytes(circle_pair_bytes_profile) + ) + assert all(pair.hash_tree_root() == circle_pair_root for pair in shape_pairs) + assert all(pair.hash_tree_root() == circle_pair_root for pair in circle_pairs) + + # Unsupported tests + shape = Shape(side=None, color=1, radius=None) + shape_bytes = bytes.fromhex("0201") + assert shape.encode_bytes() == shape_bytes + assert Shape.decode_bytes(shape_bytes) == shape + try: + shape = Square.decode_bytes(shape_bytes) + assert False + except: + pass + try: + shape = Circle.decode_bytes(shape_bytes) + assert False + except: + pass + shape = Shape(side=0x42, color=1, radius=0x42) + shape_bytes = bytes.fromhex("074200014200") + assert shape.encode_bytes() == shape_bytes + assert Shape.decode_bytes(shape_bytes) == shape + try: + shape = Square.decode_bytes(shape_bytes) + assert False + except: + pass + try: + shape = Circle.decode_bytes(shape_bytes) + assert False + except: + pass + try: + shape = Shape.decode_bytes("00") + assert False + except: + pass + try: + square = Square(radius=0x42, color=1) + assert False + except: + pass + try: + circle = Circle(side=0x42, color=1) + assert False + except: + pass + + # Surrounding container tests + class ShapeContainer(Container): + shape: Shape + square: Square + circle: Circle + + class ShapeContainerRepr(Container): + shape: ShapeRepr + square: ShapeRepr + circle: ShapeRepr + + container = ShapeContainer( + shape=Shape(side=0x42, color=1, radius=0x42), + square=Square(side=0x42, color=1), + circle=Circle(radius=0x42, color=1), + ) + container_bytes = bytes.fromhex("0a000000420001014200074200014200") + assert container.encode_bytes() == container_bytes + assert ShapeContainer.decode_bytes(container_bytes) == container + assert container.hash_tree_root() == ShapeContainerRepr( + shape=ShapeRepr( + value=ShapePayload(side=0x42, color=1, radius=0x42), + active_fields=Bitvector[4](True, True, True, False), + ), + square=ShapeRepr( + value=ShapePayload(side=0x42, color=1, radius=0), + active_fields=Bitvector[4](True, True, False, False), + ), + circle=ShapeRepr( + value=ShapePayload(side=0, color=1, radius=0x42), + active_fields=Bitvector[4](False, True, True, False), + ), + ).hash_tree_root() + + # basic container + class Shape1(StableContainer[4]): + side: Optional[uint16] + color: Optional[uint8] + radius: Optional[uint16] + + # basic container with different depth + class Shape2(StableContainer[8]): + side: Optional[uint16] + color: Optional[uint8] + radius: Optional[uint16] + + # basic container with variable fields + class Shape3(StableContainer[8]): + side: Optional[uint16] + colors: Optional[List[uint8, 4]] + radius: Optional[uint16] + + stable_container_tests = [ + { + 'value': Shape1(side=0x42, color=1, radius=0x42), + 'serialized': '074200014200', + 'hash_tree_root': '37b28eab19bc3e246e55d2e2b2027479454c27ee006d92d4847c84893a162e6d' + }, + { + 'value': Shape1(side=0x42, color=1, radius=None), + 'serialized': '03420001', + 'hash_tree_root': 'bfdb6fda9d02805e640c0f5767b8d1bb9ff4211498a5e2d7c0f36e1b88ce57ff' + }, + { + 'value': Shape1(side=None, color=1, radius=None), + 'serialized': '0201', + 'hash_tree_root': '522edd7309c0041b8eb6a218d756af558e9cf4c816441ec7e6eef42dfa47bb98' + }, + { + 'value': Shape1(side=None, color=1, radius=0x42), + 'serialized': '06014200', + 'hash_tree_root': 'f66d2c38c8d2afbd409e86c529dff728e9a4208215ca20ee44e49c3d11e145d8' + }, + { + 'value': Shape2(side=0x42, color=1, radius=0x42), + 'serialized': '074200014200', + 'hash_tree_root': '0792fb509377ee2ff3b953dd9a88eee11ac7566a8df41c6c67a85bc0b53efa4e' + }, + { + 'value': Shape2(side=0x42, color=1, radius=None), + 'serialized': '03420001', + 'hash_tree_root': 'ddc7acd38ae9d6d6788c14bd7635aeb1d7694768d7e00e1795bb6d328ec14f28' + }, + { + 'value': Shape2(side=None, color=1, radius=None), + 'serialized': '0201', + 'hash_tree_root': '9893ecf9b68030ff23c667a5f2e4a76538a8e2ab48fd060a524888a66fb938c9' + }, + { + 'value': Shape2(side=None, color=1, radius=0x42), + 'serialized': '06014200', + 'hash_tree_root': 'e823471310312d52aa1135d971a3ed72ba041ade3ec5b5077c17a39d73ab17c5' + }, + { + 'value': Shape3(side=0x42, colors=[1, 2], radius=0x42), + 'serialized': '0742000800000042000102', + 'hash_tree_root': '1093b0f1d88b1b2b458196fa860e0df7a7dc1837fe804b95d664279635cb302f' + }, + { + 'value': Shape3(side=0x42, colors=None, radius=None), + 'serialized': '014200', + 'hash_tree_root': '28df3f1c3eebd92504401b155c5cfe2f01c0604889e46ed3d22a3091dde1371f' + }, + { + 'value': Shape3(side=None, colors=[1, 2], radius=None), + 'serialized': '02040000000102', + 'hash_tree_root': '659638368467b2c052ca698fcb65902e9b42ce8e94e1f794dd5296ceac2dec3e' + }, + { + 'value': Shape3(side=None, colors=None, radius=0x42), + 'serialized': '044200', + 'hash_tree_root': 'd585dd0561c718bf4c29e4c1bd7d4efd4a5fe3c45942a7f778acb78fd0b2a4d2' + }, + { + 'value': Shape3(side=None, colors=[1, 2], radius=0x42), + 'serialized': '060600000042000102', + 'hash_tree_root': '00fc0cecc200a415a07372d5d5b8bc7ce49f52504ed3da0336f80a26d811c7bf' + } + ] + + for test in stable_container_tests: + assert test['value'].encode_bytes().hex() == test['serialized'] + assert test['value'].hash_tree_root().hex() == test['hash_tree_root'] diff --git a/remerkleable/test_typing.py b/remerkleable/test_typing.py index 6862414..0094349 100644 --- a/remerkleable/test_typing.py +++ b/remerkleable/test_typing.py @@ -2,6 +2,7 @@ import pytest # type: ignore +from typing import Optional from random import Random from remerkleable.complex import Container, Vector, List @@ -10,6 +11,7 @@ from remerkleable.bitfields import Bitvector, Bitlist from remerkleable.byte_arrays import ByteVector, Bytes1, Bytes4, Bytes8, Bytes32, Bytes48, Bytes96 from remerkleable.core import BasicView, View +from remerkleable.stable_container import Profile, StableContainer from remerkleable.union import Union from remerkleable.tree import get_depth, merkle_hash, LEFT_GINDEX, RIGHT_GINDEX @@ -362,6 +364,59 @@ class Wrapper(Container): except KeyError: pass + class StableFields(StableContainer[8]): + foo: Optional[uint32] + bar: Optional[uint64] + quix: Optional[uint64] + more: Optional[uint32] + + class FooFields(Profile[StableFields]): + foo: uint32 + more: Optional[uint32] + + class BarFields(Profile[StableFields]): + bar: uint64 + quix: uint64 + more: Optional[uint32] + + assert issubclass((StableFields / '__active_fields__').navigate_type(), Bitvector) + assert (StableFields / '__active_fields__').navigate_type().vector_length() == 8 + assert (StableFields / '__active_fields__').gindex() == 0b11 + assert (StableFields / 'foo').navigate_type() == Optional[uint32] + assert (StableFields / 'foo').gindex() == 0b10000 + assert (StableFields / 'bar').navigate_type() == Optional[uint64] + assert (StableFields / 'bar').gindex() == 0b10001 + assert (StableFields / 'quix').navigate_type() == Optional[uint64] + assert (StableFields / 'quix').gindex() == 0b10010 + assert (StableFields / 'more').navigate_type() == Optional[uint32] + assert (StableFields / 'more').gindex() == 0b10011 + + assert issubclass((FooFields / '__active_fields__').navigate_type(), Bitvector) + assert (FooFields / '__active_fields__').navigate_type().vector_length() == 8 + assert (FooFields / '__active_fields__').gindex() == 0b11 + assert (FooFields / 'foo').navigate_type() == uint32 + assert (FooFields / 'foo').gindex() == 0b10000 + assert (FooFields / 'more').navigate_type() == Optional[uint32] + assert (FooFields / 'more').gindex() == 0b10011 + try: + (FooFields / 'bar').navigate_type() + assert False + except KeyError: + pass + + assert issubclass((BarFields / '__active_fields__').navigate_type(), Bitvector) + assert (BarFields / '__active_fields__').navigate_type().vector_length() == 8 + assert (BarFields / '__active_fields__').gindex() == 0b11 + assert (BarFields / 'bar').navigate_type() == uint64 + assert (BarFields / 'bar').gindex() == 0b10001 + assert (BarFields / 'more').navigate_type() == Optional[uint32] + assert (BarFields / 'more').gindex() == 0b10011 + try: + (BarFields / 'foo').navigate_type() + assert False + except KeyError: + pass + def test_bitvector(): for size in [1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 511, 512, 513, 1023, 1024, 1025]: