From 12f63d8953dcdc466dd8ca58c50c5ae4b4307fa0 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Mon, 4 Mar 2024 13:44:17 +0100 Subject: [PATCH 01/23] add EIP-7495 implementation: `StableContainer` Introduce support for `StableContainer` so that it can be used in drafts for `ethereum/consensus-specs`. Marked as 'under review', and as such is subject to change. --- remerkleable/stable_container.py | 332 +++++++++++++++++++++++++++++++ remerkleable/test_impl.py | 217 +++++++++++++++++++- 2 files changed, 548 insertions(+), 1 deletion(-) create mode 100644 remerkleable/stable_container.py diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py new file mode 100644 index 0000000..cba54b9 --- /dev/null +++ b/remerkleable/stable_container.py @@ -0,0 +1,332 @@ +# This file implements `StableContainer` according to https://eips.ethereum.org/EIPS/eip-7495 +# The EIP is still under review, functionality may change or go away without deprecation. + +import io +from typing import BinaryIO, Dict, List as PyList, Optional, TypeVar, Type, Union as PyUnion, \ + get_args, get_origin +from textwrap import indent +from remerkleable.bitfields import Bitvector +from remerkleable.complex import ComplexView, Container, FieldOffset, \ + decode_offset, encode_offset +from remerkleable.core import View, ViewHook, OFFSET_BYTE_LENGTH +from remerkleable.tree import NavigationError, Node, PairNode, \ + get_depth, subtree_fill_to_contents, zero_node + +N = TypeVar('N') +S = TypeVar('S', bound="ComplexView") + + +class StableContainer(ComplexView): + _field_indices: Dict[str, tuple[int, Type[View], bool]] + __slots__ = '_field_indices' + + def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None, **kwargs): + if backing is not None: + if len(kwargs) != 0: + raise Exception("cannot have both a backing and elements to init fields") + return super().__new__(cls, backing=backing, hook=hook, **kwargs) + + for fkey, (ftyp, fopt) in cls.fields().items(): + if fkey not in kwargs: + if not fopt: + raise AttributeError(f"Field '{fkey}' is required in {cls}") + kwargs[fkey] = None + + input_nodes = [] + active_fields = Bitvector[cls.N]() + for findex, (fkey, (ftyp, fopt)) in enumerate(cls.fields().items()): + fnode: Node + assert fkey in kwargs + finput = kwargs.pop(fkey) + if finput is None: + fnode = zero_node(0) + active_fields.set(findex, False) + else: + if isinstance(finput, View): + fnode = finput.get_backing() + else: + fnode = ftyp.coerce_view(finput).get_backing() + active_fields.set(findex, True) + input_nodes.append(fnode) + + if len(kwargs) > 0: + raise AttributeError(f'The field names [{"".join(kwargs.keys())}] are not defined in {cls}') + + backing = PairNode( + left=subtree_fill_to_contents(input_nodes, get_depth(cls.N)), + right=active_fields.get_backing()) + return super().__new__(cls, backing=backing, hook=hook, **kwargs) + + def __init_subclass__(cls, *args, **kwargs): + super().__init_subclass__(*args, **kwargs) + cls._field_indices = { + fkey: (i, ftyp, fopt) + for i, (fkey, (ftyp, fopt)) in enumerate(cls.fields().items()) + } + + def __class_getitem__(cls, n) -> Type["StableContainer"]: + if n <= 0: + raise Exception(f"invalid stablecontainer capacity: {n}") + + class StableContainerView(StableContainer): + N = n + + StableContainerView.__name__ = StableContainerView.type_repr() + return StableContainerView + + @classmethod + def fields(cls) -> Dict[str, tuple[Type[View], bool]]: + fields = {} + for k, v in cls.__annotations__.items(): + fopt = get_origin(v) == PyUnion and type(None) in get_args(v) + ftyp = get_args(v)[0] if fopt else v + fields[k] = (ftyp, fopt) + return fields + + @classmethod + def is_fixed_byte_length(cls) -> bool: + return False + + @classmethod + def min_byte_length(cls) -> int: + total = Bitvector[cls.N].type_byte_length() + for _, (ftyp, fopt) in cls.fields().items(): + if fopt: + continue + if not ftyp.is_fixed_byte_length(): + total += OFFSET_BYTE_LENGTH + total += ftyp.min_byte_length() + return total + + @classmethod + def max_byte_length(cls) -> int: + total = Bitvector[cls.N].type_byte_length() + for _, (ftyp, _) in cls.fields().items(): + if not ftyp.is_fixed_byte_length(): + total += OFFSET_BYTE_LENGTH + total += ftyp.max_byte_length() + return total + + def active_fields(self) -> Bitvector: + active_fields_node = super().get_backing().get_right() + return Bitvector[self.__class__.N].view_from_backing(active_fields_node) + + def __getattr__(self, item): + if item[0] == '_': + return super().__getattribute__(item) + else: + try: + (findex, ftyp, fopt) = self.__class__._field_indices[item] + except KeyError: + raise AttributeError(f"unknown attribute {item}") + + if not self.active_fields().get(findex): + assert fopt + return None + + data = super().get_backing().get_left() + fnode = data.getter(2**get_depth(self.__class__.N) + findex) + return ftyp.view_from_backing(fnode) + + def __setattr__(self, key, value): + if key[0] == '_': + super().__setattr__(key, value) + else: + try: + (findex, ftyp, fopt) = self.__class__._field_indices[key] + except KeyError: + raise AttributeError(f"unknown attribute {key}") + + next_backing = self.get_backing() + + assert value is not None or fopt + active_fields = self.active_fields() + active_fields.set(findex, value is not None) + next_backing = next_backing.rebind_right(active_fields.get_backing()) + + if value is not None: + if isinstance(value, ftyp): + fnode = value.get_backing() + else: + fnode = ftyp.coerce_view(value).get_backing() + else: + fnode = zero_node(0) + data = next_backing.get_left() + next_data = data.setter(2**get_depth(self.__class__.N) + findex)(fnode) + next_backing = next_backing.rebind_left(next_data) + + self.set_backing(next_backing) + + def _get_field_val_repr(self, fkey: str, ftyp: Type[View], fopt: bool) -> str: + field_start = ' ' + fkey + ': ' + ( + ('Optional[' if fopt else '') + ftyp.__name__ + (']' if fopt else '') + ) + ' = ' + try: + field_repr = repr(getattr(self, fkey)) + if '\n' in field_repr: # if multiline, indent it, but starting from the value. + i = field_repr.index('\n') + field_repr = field_repr[:i+1] + indent(field_repr[i+1:], ' ' * len(field_start)) + return field_start + field_repr + except NavigationError: + return f"{field_start} *omitted from partial*" + + def __repr__(self): + return f"{self.__class__.type_repr()}:\n" + '\n'.join( + indent(self._get_field_val_repr(fkey, ftyp, fopt), ' ') + for fkey, (ftyp, fopt) in self.__class__.fields().items()) + + @classmethod + def type_repr(cls) -> str: + return f"StableContainer[{cls.N}]" + + @classmethod + def deserialize(cls: Type[S], stream: BinaryIO, scope: int) -> S: + num_prefix_bytes = Bitvector[cls.N].type_byte_length() + if scope < num_prefix_bytes: + raise ValueError("scope too small, cannot read StableContainer active fields") + active_fields = Bitvector[cls.N].deserialize(stream, num_prefix_bytes) + scope = scope - num_prefix_bytes + + max_findex = 0 + field_values: Dict[str, Optional[View]] = {} + dyn_fields: PyList[FieldOffset] = [] + fixed_size = 0 + for findex, (fkey, (ftyp, _)) in enumerate(cls.fields().items()): + max_findex = findex + if not active_fields.get(findex): + field_values[fkey] = None + continue + if ftyp.is_fixed_byte_length(): + fsize = ftyp.type_byte_length() + field_values[fkey] = ftyp.deserialize(stream, fsize) + fixed_size += fsize + else: + dyn_fields.append(FieldOffset( + key=fkey, typ=ftyp, offset=int(decode_offset(stream)))) + fixed_size += OFFSET_BYTE_LENGTH + if len(dyn_fields) > 0: + if dyn_fields[0].offset < fixed_size: + raise Exception(f"first offset {dyn_fields[0].offset} is " + f"smaller than expected fixed size {fixed_size}") + for i, (fkey, ftyp, foffset) in enumerate(dyn_fields): + next_offset = dyn_fields[i + 1].offset if i + 1 < len(dyn_fields) else scope + if foffset > next_offset: + raise Exception(f"offset {i} is invalid: {foffset} " + f"larger than next offset {next_offset}") + fsize = next_offset - foffset + f_min_size, f_max_size = ftyp.min_byte_length(), ftyp.max_byte_length() + if not (f_min_size <= fsize <= f_max_size): + raise Exception(f"offset {i} is invalid, size out of bounds: " + f"{foffset}, next {next_offset}, implied size: {fsize}, " + f"size bounds: [{f_min_size}, {f_max_size}]") + field_values[fkey] = ftyp.deserialize(stream, fsize) + for findex in range(max_findex + 1, cls.N): + if active_fields.get(findex): + raise Exception(f"unknown field index {findex}") + return cls(**field_values) # type: ignore + + def serialize(self, stream: BinaryIO) -> int: + active_fields = self.active_fields() + num_prefix_bytes = active_fields.serialize(stream) + + num_data_bytes = sum( + ftyp.type_byte_length() if ftyp.is_fixed_byte_length() else OFFSET_BYTE_LENGTH + for findex, (_, (ftyp, _)) in enumerate(self.__class__.fields().items()) + if active_fields.get(findex)) + + temp_dyn_stream = io.BytesIO() + data = super().get_backing().get_left() + for findex, (_, (ftyp, _)) in enumerate(self.__class__.fields().items()): + if not active_fields.get(findex): + continue + fnode = data.getter(2**get_depth(self.__class__.N) + findex) + v = ftyp.view_from_backing(fnode) + if ftyp.is_fixed_byte_length(): + v.serialize(stream) + else: + encode_offset(stream, num_data_bytes) + num_data_bytes += v.serialize(temp_dyn_stream) # type: ignore + temp_dyn_stream.seek(0) + stream.write(temp_dyn_stream.read(num_data_bytes)) + + return num_prefix_bytes + num_data_bytes + + +class Variant(ComplexView): + def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None, **kwargs): + if backing is not None: + if len(kwargs) != 0: + raise Exception("cannot have both a backing and elements to init fields") + return super().__new__(cls, backing=backing, hook=hook, **kwargs) + + extra_kwargs = kwargs.copy() + for fkey, (ftyp, fopt) in cls.fields().items(): + if fkey in extra_kwargs: + extra_kwargs.pop(fkey) + elif not fopt: + raise AttributeError(f"Field '{fkey}' is required in {cls}") + else: + pass + if len(extra_kwargs) > 0: + raise AttributeError(f'The field names [{"".join(extra_kwargs.keys())}] are not defined in {cls}') + + value = cls.S(backing, hook, **kwargs) + return cls(backing=value.get_backing()) + + def __class_getitem__(cls, s) -> Type["Variant"]: + if not issubclass(s, StableContainer): + raise Exception(f"invalid variant container: {s}") + + class VariantView(Variant, s): + S = s + + @classmethod + def fields(cls) -> Dict[str, tuple[Type[View], bool]]: + return s.fields() + + VariantView.__name__ = VariantView.type_repr() + return VariantView + + @classmethod + def type_repr(cls) -> str: + return f"Variant[{cls.S.__name__}]" + + @classmethod + def deserialize(cls: Type[S], stream: BinaryIO, scope: int) -> S: + value = cls.S.deserialize(stream, scope) + return cls(backing=value.get_backing()) + + +class OneOf(ComplexView): + def __class_getitem__(cls, s) -> Type["OneOf"]: + if not issubclass(s, StableContainer) and not issubclass(s, Container): + raise Exception(f"invalid oneof container: {s}") + + class OneOfView(OneOf, s): + S = s + + @classmethod + def fields(cls): + return s.fields() + + OneOfView.__name__ = OneOfView.type_repr() + return OneOfView + + @classmethod + def type_repr(cls) -> str: + return f"OneOf[{cls.S}]" + + @classmethod + def decode_bytes(cls: Type[S], bytez: bytes, *args, **kwargs) -> S: + stream = io.BytesIO() + stream.write(bytez) + stream.seek(0) + return cls.deserialize(stream, len(bytez), *args, **kwargs) + + @classmethod + def deserialize(cls: Type[S], stream: BinaryIO, scope: int, *args, **kwargs) -> S: + value = cls.S.deserialize(stream, scope) + v = cls.select_variant(value, *args, **kwargs) + if not issubclass(v.S, cls.S): + raise Exception(f"unsupported select_variant result: {v}") + return v(backing=value.get_backing()) diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index 280b446..fcbf2e0 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -2,13 +2,14 @@ # flake8:noqa E501 Ignore long lines, some test cases are just inherently long -from typing import Iterable, Type +from typing import Iterable, Optional, Type import io from remerkleable.complex import Container, Vector, List from remerkleable.basic import boolean, bit, byte, uint8, uint16, uint32, uint64, uint128, uint256 from remerkleable.bitfields import Bitvector, Bitlist from remerkleable.byte_arrays import ByteVector, ByteList from remerkleable.core import View, ObjType +from remerkleable.stable_container import OneOf, StableContainer, Variant from remerkleable.union import Union from hashlib import sha256 @@ -475,3 +476,217 @@ class B(Container): assert A(1, 2, 3) != B(1, 2, 3, 0) assert A(1, 2, 3) in {A(1, 2, 3)} assert A(1, 2, 3) not in {B(1, 2, 3, 0)} + + +def test_stable_container(): + # Serialization and merkleization format + class Shape(StableContainer[4]): + side: Optional[uint16] + color: uint8 + radius: Optional[uint16] + + # Valid variants + class Square(Variant[Shape]): + side: uint16 + color: uint8 + + class Circle(Variant[Shape]): + radius: uint16 + color: uint8 + + class AnyShape(OneOf[Shape]): + @classmethod + def select_variant(cls, value: Shape, circle_allowed = False) -> Type[Shape]: + if value.radius is not None: + assert circle_allowed + return Circle + if value.side is not None: + return Square + assert False + + # Helper containers for merkleization testing + class ShapePayload(Container): + side: uint16 + color: uint8 + radius: uint16 + class ShapeRepr(Container): + value: ShapePayload + active_fields: Bitvector[4] + + # Square tests + shape1 = Shape(side=0x42, color=1, radius=None) + square_bytes = bytes.fromhex("03420001") + square1 = Square(side=0x42, color=1) + square2 = Square(backing=shape1.get_backing()) + square3 = Square(backing=square1.get_backing()) + assert shape1 == square1 == square2 == square3 + assert ( + shape1.encode_bytes() == square1.encode_bytes() == + square2.encode_bytes() == square3.encode_bytes() == + square_bytes + ) + assert ( + Shape.decode_bytes(square_bytes) == + Square.decode_bytes(square_bytes) == + AnyShape.decode_bytes(square_bytes) == + AnyShape.decode_bytes(square_bytes, circle_allowed = True) + ) + assert ( + shape1.hash_tree_root() == square1.hash_tree_root() == + square2.hash_tree_root() == square3.hash_tree_root() == + ShapeRepr( + value=ShapePayload(side=0x42, color=1, radius=0), + active_fields=Bitvector[4](True, True, False, False), + ).hash_tree_root() + ) + try: + circle = Circle(side=0x42, color=1) + assert False + except: + pass + try: + circle = Circle(backing=shape1.get_backing()) + assert False + except: + pass + try: + circle = Circle.decode_bytes(square_bytes) + assert False + except: + pass + shape1.side = 0x1337 + square1.side = 0x1337 + square2.side = 0x1337 + square3.side = 0x1337 + square_bytes = bytes.fromhex("03371301") + assert shape1 == square1 == square2 == square3 + assert ( + shape1.encode_bytes() == square1.encode_bytes() == + square2.encode_bytes() == square3.encode_bytes() == + square_bytes + ) + assert ( + Shape.decode_bytes(square_bytes) == + Square.decode_bytes(square_bytes) == + AnyShape.decode_bytes(square_bytes) == + AnyShape.decode_bytes(square_bytes, circle_allowed = True) + ) + assert ( + shape1.hash_tree_root() == square1.hash_tree_root() == + square2.hash_tree_root() == square3.hash_tree_root() == + ShapeRepr( + value=ShapePayload(side=0x1337, color=1, radius=0), + active_fields=Bitvector[4](True, True, False, False), + ).hash_tree_root() + ) + try: + square1.radius = 0x1337 + assert False + except: + pass + try: + square1.side = None + assert False + except: + pass + + # Circle tests + shape2 = Shape(side=None, color=1, radius=0x42) + circle_bytes = bytes.fromhex("06014200") + circle1 = Circle(radius=0x42, color=1) + circle2 = Circle(backing=shape2.get_backing()) + circle3 = Circle(backing=circle1.get_backing()) + circle4 = shape1 + circle4.side = None + circle4.radius = 0x42 + assert shape2 == circle1 == circle2 == circle3 == circle4 + assert ( + shape2.encode_bytes() == circle1.encode_bytes() == + circle2.encode_bytes() == circle3.encode_bytes() == + circle4.encode_bytes() == + circle_bytes + ) + assert ( + Shape.decode_bytes(circle_bytes) == + Circle.decode_bytes(circle_bytes) == + AnyShape.decode_bytes(circle_bytes, circle_allowed = True) + ) + assert ( + shape2.hash_tree_root() == circle1.hash_tree_root() == + circle2.hash_tree_root() == circle3.hash_tree_root() == + circle4.hash_tree_root() == + ShapeRepr( + value=ShapePayload(side=0, color=1, radius=0x42), + active_fields=Bitvector[4](False, True, True, False), + ).hash_tree_root() + ) + try: + square = Square(radius=0x42, color=1) + assert False + except: + pass + try: + square = Square(backing=shape2.get_backing()) + assert False + except: + pass + try: + square = Square.decode_bytes(circle_bytes) + assert False + except: + pass + try: + circle = AnyShape.decode_bytes(circle_bytes, circle_allowed = False) + assert False + except: + pass + + # Unsupported tests + shape3 = Shape(side=None, color=1, radius=None) + shape3_bytes = bytes.fromhex("0201") + assert shape3.encode_bytes() == shape3_bytes + assert Shape.decode_bytes(shape3_bytes) == shape3 + try: + shape = Square.decode_bytes(shape3_bytes) + assert False + except: + pass + try: + shape = Circle.decode_bytes(shape3_bytes) + assert False + except: + pass + try: + shape = AnyShape.decode_bytes(shape3_bytes) + assert False + except: + pass + shape4 = Shape(side=0x42, color=1, radius=0x42) + shape4_bytes = bytes.fromhex("074200014200") + assert shape4.encode_bytes() == shape4_bytes + assert Shape.decode_bytes(shape4_bytes) == shape4 + try: + shape = Square.decode_bytes(shape4_bytes) + assert False + except: + pass + try: + shape = Circle.decode_bytes(shape4_bytes) + assert False + except: + pass + try: + shape = AnyShape.decode_bytes(shape4_bytes) + assert False + except: + pass + try: + shape = AnyShape.decode_bytes("00") + assert False + except: + pass + try: + shape = Shape.decode_bytes("00") + assert False + except: + pass From 4fd59e995a8a0fbe3543286b40213cd18e7ff888 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Mon, 4 Mar 2024 13:56:19 +0100 Subject: [PATCH 02/23] python 3.8 compat --- remerkleable/stable_container.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index cba54b9..c4cccb4 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -2,7 +2,7 @@ # The EIP is still under review, functionality may change or go away without deprecation. import io -from typing import BinaryIO, Dict, List as PyList, Optional, TypeVar, Type, Union as PyUnion, \ +from typing import BinaryIO, Dict, List as PyList, Optional, Tuple, TypeVar, Type, Union as PyUnion, \ get_args, get_origin from textwrap import indent from remerkleable.bitfields import Bitvector @@ -17,7 +17,7 @@ class StableContainer(ComplexView): - _field_indices: Dict[str, tuple[int, Type[View], bool]] + _field_indices: Dict[str, Tuple[int, Type[View], bool]] __slots__ = '_field_indices' def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None, **kwargs): @@ -75,7 +75,7 @@ class StableContainerView(StableContainer): return StableContainerView @classmethod - def fields(cls) -> Dict[str, tuple[Type[View], bool]]: + def fields(cls) -> Dict[str, Tuple[Type[View], bool]]: fields = {} for k, v in cls.__annotations__.items(): fopt = get_origin(v) == PyUnion and type(None) in get_args(v) @@ -281,7 +281,7 @@ class VariantView(Variant, s): S = s @classmethod - def fields(cls) -> Dict[str, tuple[Type[View], bool]]: + def fields(cls) -> Dict[str, Tuple[Type[View], bool]]: return s.fields() VariantView.__name__ = VariantView.type_repr() From 87e16cde8a1a030b93cfb82835518861cfe1b2ca Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Mon, 15 Apr 2024 18:13:59 +0200 Subject: [PATCH 03/23] bump for https://github.com/ethereum/EIPs/pull/8436 --- remerkleable/stable_container.py | 172 +++++++++++++++++-- remerkleable/test_impl.py | 276 ++++++++++++++++++------------- 2 files changed, 320 insertions(+), 128 deletions(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index c4cccb4..57aa835 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -16,6 +16,15 @@ S = TypeVar('S', bound="ComplexView") +def all_fields(cls) -> Dict[str, Tuple[Type[View], bool]]: + fields = {} + for k, v in cls.__annotations__.items(): + fopt = get_origin(v) == PyUnion and type(None) in get_args(v) + ftyp = get_args(v)[0] if fopt else v + fields[k] = (ftyp, fopt) + return fields + + class StableContainer(ComplexView): _field_indices: Dict[str, Tuple[int, Type[View], bool]] __slots__ = '_field_indices' @@ -76,12 +85,7 @@ class StableContainerView(StableContainer): @classmethod def fields(cls) -> Dict[str, Tuple[Type[View], bool]]: - fields = {} - for k, v in cls.__annotations__.items(): - fopt = get_origin(v) == PyUnion and type(None) in get_args(v) - ftyp = get_args(v)[0] if fopt else v - fields[k] = (ftyp, fopt) - return fields + return all_fields(cls) @classmethod def is_fixed_byte_length(cls) -> bool: @@ -253,6 +257,8 @@ def serialize(self, stream: BinaryIO) -> int: class Variant(ComplexView): + _o: int + def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None, **kwargs): if backing is not None: if len(kwargs) != 0: @@ -273,6 +279,13 @@ def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None value = cls.S(backing, hook, **kwargs) return cls(backing=value.get_backing()) + def __init_subclass__(cls, *args, **kwargs): + super().__init_subclass__(*args, **kwargs) + cls._o = 0 + for _, (_, fopt) in cls.fields().items(): + if fopt: + cls._o += 1 + def __class_getitem__(cls, s) -> Type["Variant"]: if not issubclass(s, StableContainer): raise Exception(f"invalid variant container: {s}") @@ -280,21 +293,156 @@ def __class_getitem__(cls, s) -> Type["Variant"]: class VariantView(Variant, s): S = s - @classmethod - def fields(cls) -> Dict[str, Tuple[Type[View], bool]]: - return s.fields() - VariantView.__name__ = VariantView.type_repr() return VariantView + @classmethod + def fields(cls) -> Dict[str, Tuple[Type[View], bool]]: + return all_fields(cls) + + @classmethod + def is_fixed_byte_length(cls) -> bool: + if cls._o > 0: + return False + for _, (ftyp, _) in cls.fields().items(): + if not ftyp.is_fixed_byte_length(): + return False + return True + + @classmethod + def type_byte_length(cls) -> int: + if cls.is_fixed_byte_length(): + return cls.min_byte_length() + else: + raise Exception("dynamic length variant does not have a fixed byte length") + + @classmethod + def min_byte_length(cls) -> int: + total = Bitvector[cls._o].type_byte_length() if cls._o > 0 else 0 + for _, (ftyp, fopt) in cls.fields().items(): + if fopt: + continue + if not ftyp.is_fixed_byte_length(): + total += OFFSET_BYTE_LENGTH + total += ftyp.min_byte_length() + return total + + @classmethod + def max_byte_length(cls) -> int: + total = Bitvector[cls._o].type_byte_length() if cls._o > 0 else 0 + for _, (ftyp, _) in cls.fields().items(): + if not ftyp.is_fixed_byte_length(): + total += OFFSET_BYTE_LENGTH + total += ftyp.max_byte_length() + return total + + def active_fields(self) -> Bitvector: + active_fields_node = super().get_backing().get_right() + return Bitvector[self.__class__.S.N].view_from_backing(active_fields_node) + + def optional_fields(self) -> Bitvector: + assert self.__class__._o > 0 + active_fields = self.active_fields() + optional_fields = Bitvector[self.__class__._o]() + oindex = 0 + for fkey, (_, fopt) in self.__class__.fields().items(): + if fopt: + (findex, _, _) = self.__class__.S._field_indices[fkey] + optional_fields.set(oindex, active_fields.get(findex)) + oindex += 1 + return optional_fields + @classmethod def type_repr(cls) -> str: return f"Variant[{cls.S.__name__}]" @classmethod def deserialize(cls: Type[S], stream: BinaryIO, scope: int) -> S: - value = cls.S.deserialize(stream, scope) - return cls(backing=value.get_backing()) + if cls._o > 0: + num_prefix_bytes = Bitvector[cls._o].type_byte_length() + if scope < num_prefix_bytes: + raise ValueError("scope too small, cannot read Variant optional fields") + optional_fields = Bitvector[cls._o].deserialize(stream, num_prefix_bytes) + scope = scope - num_prefix_bytes + + field_values: Dict[str, Optional[View]] = {} + dyn_fields: PyList[FieldOffset] = [] + fixed_size = 0 + oindex = 0 + for fkey, (ftyp, fopt) in cls.fields().items(): + if fopt: + have_field = optional_fields.get(oindex) + oindex += 1 + if not have_field: + field_values[fkey] = None + continue + if ftyp.is_fixed_byte_length(): + fsize = ftyp.type_byte_length() + field_values[fkey] = ftyp.deserialize(stream, fsize) + fixed_size += fsize + else: + dyn_fields.append(FieldOffset( + key=fkey, typ=ftyp, offset=int(decode_offset(stream)))) + fixed_size += OFFSET_BYTE_LENGTH + assert oindex == cls._o + if len(dyn_fields) > 0: + if dyn_fields[0].offset < fixed_size: + raise Exception(f"first offset {dyn_fields[0].offset} is " + f"smaller than expected fixed size {fixed_size}") + for i, (fkey, ftyp, foffset) in enumerate(dyn_fields): + next_offset = dyn_fields[i + 1].offset if i + 1 < len(dyn_fields) else scope + if foffset > next_offset: + raise Exception(f"offset {i} is invalid: {foffset} " + f"larger than next offset {next_offset}") + fsize = next_offset - foffset + f_min_size, f_max_size = ftyp.min_byte_length(), ftyp.max_byte_length() + if not (f_min_size <= fsize <= f_max_size): + raise Exception(f"offset {i} is invalid, size out of bounds: " + f"{foffset}, next {next_offset}, implied size: {fsize}, " + f"size bounds: [{f_min_size}, {f_max_size}]") + field_values[fkey] = ftyp.deserialize(stream, fsize) + + return cls(**field_values) # type: ignore + + def serialize(self, stream: BinaryIO) -> int: + if self.__class__._o > 0: + optional_fields = self.optional_fields() + num_prefix_bytes = optional_fields.serialize(stream) + else: + num_prefix_bytes = 0 + + num_data_bytes = 0 + oindex = 0 + for _, (ftyp, fopt) in self.__class__.fields().items(): + if fopt: + have_field = optional_fields.get(oindex) + oindex += 1 + if not have_field: + continue + if ftyp.is_fixed_byte_length(): + num_data_bytes += ftyp.type_byte_length() + else: + num_data_bytes += OFFSET_BYTE_LENGTH + assert oindex == self.__class__._o + + temp_dyn_stream = io.BytesIO() + data = super().get_backing().get_left() + active_fields = self.active_fields() + for fkey, (ftyp, _) in self.__class__.fields().items(): + (findex, _, _) = self.__class__.S._field_indices[fkey] + if not active_fields.get(findex): + continue + fnode = data.getter(2**get_depth(self.__class__.N) + findex) + v = ftyp.view_from_backing(fnode) + if ftyp.is_fixed_byte_length(): + v.serialize(stream) + else: + encode_offset(stream, num_data_bytes) + num_data_bytes += v.serialize(temp_dyn_stream) # type: ignore + temp_dyn_stream.seek(0) + stream.write(temp_dyn_stream.read(num_data_bytes)) + + return num_prefix_bytes + num_data_bytes class OneOf(ComplexView): diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index fcbf2e0..01fc50b 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -479,17 +479,18 @@ class B(Container): def test_stable_container(): - # Serialization and merkleization format + # Defines the common merkleization format and a portable serialization format across variants class Shape(StableContainer[4]): side: Optional[uint16] color: uint8 radius: Optional[uint16] - # Valid variants + # Inherits merkleization format from `Shape`, but is serialized more compactly class Square(Variant[Shape]): side: uint16 color: uint8 + # Inherits merkleization format from `Shape`, but is serialized more compactly class Circle(Variant[Shape]): radius: uint16 color: uint8 @@ -514,179 +515,222 @@ class ShapeRepr(Container): active_fields: Bitvector[4] # Square tests - shape1 = Shape(side=0x42, color=1, radius=None) - square_bytes = bytes.fromhex("03420001") - square1 = Square(side=0x42, color=1) - square2 = Square(backing=shape1.get_backing()) - square3 = Square(backing=square1.get_backing()) - assert shape1 == square1 == square2 == square3 + square_bytes_stable = bytes.fromhex("03420001") + square_bytes_variant = bytes.fromhex("420001") + square_root = ShapeRepr( + value=ShapePayload(side=0x42, color=1, radius=0), + active_fields=Bitvector[4](True, True, False, False), + ).hash_tree_root() + shapes = [Shape(side=0x42, color=1, radius=None)] + squares = [Square(side=0x42, color=1)] + squares.extend(list(Square(backing=shape.get_backing()) for shape in shapes)) + shapes.extend(list(Shape(backing=shape.get_backing()) for shape in shapes)) + shapes.extend(list(Shape(backing=square.get_backing()) for square in squares)) + squares.extend(list(Square(backing=square.get_backing()) for square in squares)) + assert len(set(shapes)) == 1 + assert len(set(squares)) == 1 + assert all(shape.encode_bytes() == square_bytes_stable for shape in shapes) + assert all(square.encode_bytes() == square_bytes_variant for square in squares) assert ( - shape1.encode_bytes() == square1.encode_bytes() == - square2.encode_bytes() == square3.encode_bytes() == - square_bytes - ) - assert ( - Shape.decode_bytes(square_bytes) == - Square.decode_bytes(square_bytes) == - AnyShape.decode_bytes(square_bytes) == - AnyShape.decode_bytes(square_bytes, circle_allowed = True) - ) - assert ( - shape1.hash_tree_root() == square1.hash_tree_root() == - square2.hash_tree_root() == square3.hash_tree_root() == - ShapeRepr( - value=ShapePayload(side=0x42, color=1, radius=0), - active_fields=Bitvector[4](True, True, False, False), - ).hash_tree_root() + Square(backing=Shape.decode_bytes(square_bytes_stable).get_backing()) == + Square.decode_bytes(square_bytes_variant) == + AnyShape.decode_bytes(square_bytes_stable) == + AnyShape.decode_bytes(square_bytes_stable, circle_allowed = True) ) + assert all(shape.hash_tree_root() == square_root for shape in shapes) + assert all(square.hash_tree_root() == square_root for square in squares) try: circle = Circle(side=0x42, color=1) assert False except: pass - try: - circle = Circle(backing=shape1.get_backing()) - assert False - except: - pass - try: - circle = Circle.decode_bytes(square_bytes) - assert False - except: - pass - shape1.side = 0x1337 - square1.side = 0x1337 - square2.side = 0x1337 - square3.side = 0x1337 - square_bytes = bytes.fromhex("03371301") - assert shape1 == square1 == square2 == square3 - assert ( - shape1.encode_bytes() == square1.encode_bytes() == - square2.encode_bytes() == square3.encode_bytes() == - square_bytes - ) - assert ( - Shape.decode_bytes(square_bytes) == - Square.decode_bytes(square_bytes) == - AnyShape.decode_bytes(square_bytes) == - AnyShape.decode_bytes(square_bytes, circle_allowed = True) - ) + for shape in shapes: + try: + circle = Circle(backing=shape.get_backing()) + assert False + except: + pass + for square in squares: + try: + circle = Circle(backing=square.get_backing()) + assert False + except: + pass + for shape in shapes: + shape.side = 0x1337 + for square in squares: + square.side = 0x1337 + square_bytes_stable = bytes.fromhex("03371301") + square_bytes_variant = bytes.fromhex("371301") + square_root = ShapeRepr( + value=ShapePayload(side=0x1337, color=1, radius=0), + active_fields=Bitvector[4](True, True, False, False), + ).hash_tree_root() + assert len(set(shapes)) == 1 + assert len(set(squares)) == 1 + assert all(shape.encode_bytes() == square_bytes_stable for shape in shapes) + assert all(square.encode_bytes() == square_bytes_variant for square in squares) assert ( - shape1.hash_tree_root() == square1.hash_tree_root() == - square2.hash_tree_root() == square3.hash_tree_root() == - ShapeRepr( - value=ShapePayload(side=0x1337, color=1, radius=0), - active_fields=Bitvector[4](True, True, False, False), - ).hash_tree_root() + Square(backing=Shape.decode_bytes(square_bytes_stable).get_backing()) == + Square.decode_bytes(square_bytes_variant) == + AnyShape.decode_bytes(square_bytes_stable) == + AnyShape.decode_bytes(square_bytes_stable, circle_allowed = True) ) - try: - square1.radius = 0x1337 - assert False - except: - pass - try: - square1.side = None - assert False - except: - pass + assert all(shape.hash_tree_root() == square_root for shape in shapes) + assert all(square.hash_tree_root() == square_root for square in squares) + for square in squares: + try: + square.radius = 0x1337 + assert False + except: + pass + for square in squares: + try: + square.side = None + assert False + except: + pass # Circle tests - shape2 = Shape(side=None, color=1, radius=0x42) - circle_bytes = bytes.fromhex("06014200") - circle1 = Circle(radius=0x42, color=1) - circle2 = Circle(backing=shape2.get_backing()) - circle3 = Circle(backing=circle1.get_backing()) - circle4 = shape1 - circle4.side = None - circle4.radius = 0x42 - assert shape2 == circle1 == circle2 == circle3 == circle4 + circle_bytes_stable = bytes.fromhex("06014200") + circle_bytes_variant = bytes.fromhex("420001") + circle_root = ShapeRepr( + value=ShapePayload(side=0, color=1, radius=0x42), + active_fields=Bitvector[4](False, True, True, False), + ).hash_tree_root() + modified_shape = shapes[0] + modified_shape.side = None + modified_shape.radius = 0x42 + shapes = [Shape(side=None, color=1, radius=0x42), modified_shape] + circles = [Circle(radius=0x42, color=1)] + circles.extend(list(Circle(backing=shape.get_backing()) for shape in shapes)) + shapes.extend(list(Shape(backing=shape.get_backing()) for shape in shapes)) + shapes.extend(list(Shape(backing=circle.get_backing()) for circle in circles)) + circles.extend(list(Circle(backing=circle.get_backing()) for circle in circles)) + assert len(set(shapes)) == 1 + assert len(set(circles)) == 1 + assert all(shape.encode_bytes() == circle_bytes_stable for shape in shapes) + assert all(circle.encode_bytes() == circle_bytes_variant for circle in circles) assert ( - shape2.encode_bytes() == circle1.encode_bytes() == - circle2.encode_bytes() == circle3.encode_bytes() == - circle4.encode_bytes() == - circle_bytes - ) - assert ( - Shape.decode_bytes(circle_bytes) == - Circle.decode_bytes(circle_bytes) == - AnyShape.decode_bytes(circle_bytes, circle_allowed = True) - ) - assert ( - shape2.hash_tree_root() == circle1.hash_tree_root() == - circle2.hash_tree_root() == circle3.hash_tree_root() == - circle4.hash_tree_root() == - ShapeRepr( - value=ShapePayload(side=0, color=1, radius=0x42), - active_fields=Bitvector[4](False, True, True, False), - ).hash_tree_root() + Circle(backing=Shape.decode_bytes(circle_bytes_stable).get_backing()) == + Circle.decode_bytes(circle_bytes_variant) == + AnyShape.decode_bytes(circle_bytes_stable, circle_allowed = True) ) + assert all(shape.hash_tree_root() == circle_root for shape in shapes) + assert all(circle.hash_tree_root() == circle_root for circle in circles) try: square = Square(radius=0x42, color=1) assert False except: pass + for shape in shapes: + try: + square = Square(backing=shape.get_backing()) + assert False + except: + pass + for circle in circles: + try: + square = Square(backing=circle.get_backing()) + assert False + except: + pass try: - square = Square(backing=shape2.get_backing()) + circle = AnyShape.decode_bytes(circle_bytes_stable, circle_allowed = False) assert False except: pass + + # Unsupported tests + shape = Shape(side=None, color=1, radius=None) + shape_bytes = bytes.fromhex("0201") + assert shape.encode_bytes() == shape_bytes + assert Shape.decode_bytes(shape_bytes) == shape try: - square = Square.decode_bytes(circle_bytes) + shape = Square.decode_bytes(shape_bytes) assert False except: pass try: - circle = AnyShape.decode_bytes(circle_bytes, circle_allowed = False) + shape = Circle.decode_bytes(shape_bytes) assert False except: pass - - # Unsupported tests - shape3 = Shape(side=None, color=1, radius=None) - shape3_bytes = bytes.fromhex("0201") - assert shape3.encode_bytes() == shape3_bytes - assert Shape.decode_bytes(shape3_bytes) == shape3 try: - shape = Square.decode_bytes(shape3_bytes) + shape = AnyShape.decode_bytes(shape_bytes) assert False except: pass + shape = Shape(side=0x42, color=1, radius=0x42) + shape_bytes = bytes.fromhex("074200014200") + assert shape.encode_bytes() == shape_bytes + assert Shape.decode_bytes(shape_bytes) == shape try: - shape = Circle.decode_bytes(shape3_bytes) + shape = Square.decode_bytes(shape_bytes) assert False except: pass try: - shape = AnyShape.decode_bytes(shape3_bytes) + shape = Circle.decode_bytes(shape_bytes) assert False except: pass - shape4 = Shape(side=0x42, color=1, radius=0x42) - shape4_bytes = bytes.fromhex("074200014200") - assert shape4.encode_bytes() == shape4_bytes - assert Shape.decode_bytes(shape4_bytes) == shape4 try: - shape = Square.decode_bytes(shape4_bytes) + shape = AnyShape.decode_bytes(shape_bytes) assert False except: pass try: - shape = Circle.decode_bytes(shape4_bytes) + shape = AnyShape.decode_bytes("00") assert False except: pass try: - shape = AnyShape.decode_bytes(shape4_bytes) + shape = Shape.decode_bytes("00") assert False except: pass try: - shape = AnyShape.decode_bytes("00") + square = Square(radius=0x42, color=1) assert False except: pass try: - shape = Shape.decode_bytes("00") + circle = Circle(side=0x42, color=1) assert False except: pass + + # Surrounding container tests + class ShapeContainer(Container): + shape: Shape + square: Square + circle: Circle + + class ShapeContainerRepr(Container): + shape: ShapeRepr + square: ShapeRepr + circle: ShapeRepr + + container = ShapeContainer( + shape=Shape(side=0x42, color=1, radius=0x42), + square=Square(side=0x42, color=1), + circle=Circle(radius=0x42, color=1), + ) + container_bytes = bytes.fromhex("0a000000420001420001074200014200") + assert container.encode_bytes() == container_bytes + assert ShapeContainer.decode_bytes(container_bytes) == container + assert container.hash_tree_root() == ShapeContainerRepr( + shape=ShapeRepr( + value=ShapePayload(side=0x42, color=1, radius=0x42), + active_fields=Bitvector[4](True, True, True, False), + ), + square=ShapeRepr( + value=ShapePayload(side=0x42, color=1, radius=0), + active_fields=Bitvector[4](True, True, False, False), + ), + circle=ShapeRepr( + value=ShapePayload(side=0, color=1, radius=0x42), + active_fields=Bitvector[4](False, True, True, False), + ), + ).hash_tree_root() From 6c8c20232d61fda0912521a39f83f05c0a1ebe4d Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Mon, 6 May 2024 22:24:05 +0200 Subject: [PATCH 04/23] fix property accessors for `Variant[S]` --- remerkleable/stable_container.py | 94 ++++++++++++++++++++++++++------ 1 file changed, 78 insertions(+), 16 deletions(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 57aa835..b88eeb3 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -25,6 +25,26 @@ def all_fields(cls) -> Dict[str, Tuple[Type[View], bool]]: return fields +def field_val_repr(self, fkey: str, ftyp: Type[View], fopt: bool) -> str: + field_start = ' ' + fkey + ': ' + ( + ('Optional[' if fopt else '') + ftyp.__name__ + (']' if fopt else '') + ) + ' = ' + try: + field_repr = repr(getattr(self, fkey)) + if '\n' in field_repr: # if multiline, indent it, but starting from the value. + i = field_repr.index('\n') + field_repr = field_repr[:i+1] + indent(field_repr[i+1:], ' ' * len(field_start)) + return field_start + field_repr + except NavigationError: + return f"{field_start} *omitted*" + + +def repr(self) -> str: + return f"{self.__class__.type_repr()}:\n" + '\n'.join( + indent(field_val_repr(self, fkey, ftyp, fopt), ' ') + for fkey, (ftyp, fopt) in self.__class__.fields().items()) + + class StableContainer(ComplexView): _field_indices: Dict[str, Tuple[int, Type[View], bool]] __slots__ = '_field_indices' @@ -161,23 +181,8 @@ def __setattr__(self, key, value): self.set_backing(next_backing) - def _get_field_val_repr(self, fkey: str, ftyp: Type[View], fopt: bool) -> str: - field_start = ' ' + fkey + ': ' + ( - ('Optional[' if fopt else '') + ftyp.__name__ + (']' if fopt else '') - ) + ' = ' - try: - field_repr = repr(getattr(self, fkey)) - if '\n' in field_repr: # if multiline, indent it, but starting from the value. - i = field_repr.index('\n') - field_repr = field_repr[:i+1] + indent(field_repr[i+1:], ' ' * len(field_start)) - return field_start + field_repr - except NavigationError: - return f"{field_start} *omitted from partial*" - def __repr__(self): - return f"{self.__class__.type_repr()}:\n" + '\n'.join( - indent(self._get_field_val_repr(fkey, ftyp, fopt), ' ') - for fkey, (ftyp, fopt) in self.__class__.fields().items()) + return repr(self) @classmethod def type_repr(cls) -> str: @@ -352,6 +357,63 @@ def optional_fields(self) -> Bitvector: oindex += 1 return optional_fields + def __getattr__(self, item): + if item[0] == '_': + return super().__getattribute__(item) + else: + try: + (ftyp, fopt) = self.__class__.fields()[item] + except KeyError: + raise AttributeError(f"unknown attribute {item}") + try: + (findex, _, _) = self.__class__.S._field_indices[item] + except KeyError: + raise AttributeError(f"unknown attribute {item} in base") + + if not self.active_fields().get(findex): + assert fopt + return None + + data = super().get_backing().get_left() + fnode = data.getter(2**get_depth(self.__class__.S.N) + findex) + return ftyp.view_from_backing(fnode) + + def __setattr__(self, key, value): + if key[0] == '_': + super().__setattr__(key, value) + else: + try: + (ftyp, fopt) = self.__class__.fields()[key] + except KeyError: + raise AttributeError(f"unknown attribute {key}") + try: + (findex, _, _) = self.__class__.S._field_indices[key] + except KeyError: + raise AttributeError(f"unknown attribute {key} in base") + + next_backing = self.get_backing() + + assert value is not None or fopt + active_fields = self.active_fields() + active_fields.set(findex, value is not None) + next_backing = next_backing.rebind_right(active_fields.get_backing()) + + if value is not None: + if isinstance(value, ftyp): + fnode = value.get_backing() + else: + fnode = ftyp.coerce_view(value).get_backing() + else: + fnode = zero_node(0) + data = next_backing.get_left() + next_data = data.setter(2**get_depth(self.__class__.S.N) + findex)(fnode) + next_backing = next_backing.rebind_left(next_data) + + self.set_backing(next_backing) + + def __repr__(self): + return repr(self) + @classmethod def type_repr(cls) -> str: return f"Variant[{cls.S.__name__}]" From cd94c6c70dc2087a13438fcff904199417def055 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Mon, 13 May 2024 22:13:14 +0300 Subject: [PATCH 05/23] fix printing of objects --- remerkleable/stable_container.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index b88eeb3..ca64254 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -30,7 +30,7 @@ def field_val_repr(self, fkey: str, ftyp: Type[View], fopt: bool) -> str: ('Optional[' if fopt else '') + ftyp.__name__ + (']' if fopt else '') ) + ' = ' try: - field_repr = repr(getattr(self, fkey)) + field_repr = getattr(self, fkey).__repr__() if '\n' in field_repr: # if multiline, indent it, but starting from the value. i = field_repr.index('\n') field_repr = field_repr[:i+1] + indent(field_repr[i+1:], ' ' * len(field_start)) @@ -522,6 +522,9 @@ def fields(cls): OneOfView.__name__ = OneOfView.type_repr() return OneOfView + def __repr__(self): + return repr(self) + @classmethod def type_repr(cls) -> str: return f"OneOf[{cls.S}]" From 2a669767f17a8a0aa56cf57f8421198740c18e8d Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Tue, 14 May 2024 12:13:19 +0300 Subject: [PATCH 06/23] Rename `Variant` > `MerkleizeAs` and allow `MerkleizeAs[Container]` --- remerkleable/stable_container.py | 96 +++++++++++++++----------- remerkleable/test_impl.py | 112 +++++++++++++++++++++++++++++-- 2 files changed, 166 insertions(+), 42 deletions(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index ca64254..c5405de 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -13,6 +13,7 @@ get_depth, subtree_fill_to_contents, zero_node N = TypeVar('N') +B = TypeVar('B', bound="ComplexView") S = TypeVar('S', bound="ComplexView") @@ -261,7 +262,7 @@ def serialize(self, stream: BinaryIO) -> int: return num_prefix_bytes + num_data_bytes -class Variant(ComplexView): +class MerkleizeAs(ComplexView): _o: int def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None, **kwargs): @@ -281,7 +282,7 @@ def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None if len(extra_kwargs) > 0: raise AttributeError(f'The field names [{"".join(extra_kwargs.keys())}] are not defined in {cls}') - value = cls.S(backing, hook, **kwargs) + value = cls.B(backing, hook, **kwargs) return cls(backing=value.get_backing()) def __init_subclass__(cls, *args, **kwargs): @@ -290,16 +291,17 @@ def __init_subclass__(cls, *args, **kwargs): for _, (_, fopt) in cls.fields().items(): if fopt: cls._o += 1 + assert cls._o == 0 or issubclass(cls.B, StableContainer) - def __class_getitem__(cls, s) -> Type["Variant"]: - if not issubclass(s, StableContainer): - raise Exception(f"invalid variant container: {s}") + def __class_getitem__(cls, b) -> Type["MerkleizeAs"]: + if not issubclass(b, StableContainer) and not issubclass(b, Container): + raise Exception(f"invalid MerkleizeAs base: {b}") - class VariantView(Variant, s): - S = s + class MerkleizeAsView(MerkleizeAs, b): + B = b - VariantView.__name__ = VariantView.type_repr() - return VariantView + MerkleizeAsView.__name__ = MerkleizeAsView.type_repr() + return MerkleizeAsView @classmethod def fields(cls) -> Dict[str, Tuple[Type[View], bool]]: @@ -319,7 +321,7 @@ def type_byte_length(cls) -> int: if cls.is_fixed_byte_length(): return cls.min_byte_length() else: - raise Exception("dynamic length variant does not have a fixed byte length") + raise Exception("dynamic length MerkleizeAs does not have a fixed byte length") @classmethod def min_byte_length(cls) -> int: @@ -342,17 +344,19 @@ def max_byte_length(cls) -> int: return total def active_fields(self) -> Bitvector: + assert issubclass(self.__class__.B, StableContainer) active_fields_node = super().get_backing().get_right() - return Bitvector[self.__class__.S.N].view_from_backing(active_fields_node) + return Bitvector[self.__class__.B.N].view_from_backing(active_fields_node) def optional_fields(self) -> Bitvector: + assert issubclass(self.__class__.B, StableContainer) assert self.__class__._o > 0 active_fields = self.active_fields() optional_fields = Bitvector[self.__class__._o]() oindex = 0 for fkey, (_, fopt) in self.__class__.fields().items(): if fopt: - (findex, _, _) = self.__class__.S._field_indices[fkey] + (findex, _, _) = self.__class__.B._field_indices[fkey] optional_fields.set(oindex, active_fields.get(findex)) oindex += 1 return optional_fields @@ -366,16 +370,19 @@ def __getattr__(self, item): except KeyError: raise AttributeError(f"unknown attribute {item}") try: - (findex, _, _) = self.__class__.S._field_indices[item] + (findex, _, _) = self.__class__.B._field_indices[item] except KeyError: raise AttributeError(f"unknown attribute {item} in base") + if not issubclass(self.__class__.B, StableContainer): + return super().get(findex) + if not self.active_fields().get(findex): assert fopt return None data = super().get_backing().get_left() - fnode = data.getter(2**get_depth(self.__class__.S.N) + findex) + fnode = data.getter(2**get_depth(self.__class__.B.N) + findex) return ftyp.view_from_backing(fnode) def __setattr__(self, key, value): @@ -387,10 +394,14 @@ def __setattr__(self, key, value): except KeyError: raise AttributeError(f"unknown attribute {key}") try: - (findex, _, _) = self.__class__.S._field_indices[key] + (findex, _, _) = self.__class__.B._field_indices[key] except KeyError: raise AttributeError(f"unknown attribute {key} in base") + if not issubclass(self.__class__.B, StableContainer): + super().set(findex, value) + return + next_backing = self.get_backing() assert value is not None or fopt @@ -406,7 +417,7 @@ def __setattr__(self, key, value): else: fnode = zero_node(0) data = next_backing.get_left() - next_data = data.setter(2**get_depth(self.__class__.S.N) + findex)(fnode) + next_data = data.setter(2**get_depth(self.__class__.B.N) + findex)(fnode) next_backing = next_backing.rebind_left(next_data) self.set_backing(next_backing) @@ -416,14 +427,14 @@ def __repr__(self): @classmethod def type_repr(cls) -> str: - return f"Variant[{cls.S.__name__}]" + return f"MerkleizeAs[{cls.B.__name__}]" @classmethod - def deserialize(cls: Type[S], stream: BinaryIO, scope: int) -> S: + def deserialize(cls: Type[B], stream: BinaryIO, scope: int) -> B: if cls._o > 0: num_prefix_bytes = Bitvector[cls._o].type_byte_length() if scope < num_prefix_bytes: - raise ValueError("scope too small, cannot read Variant optional fields") + raise ValueError("scope too small, cannot read MerkleizeAs optional fields") optional_fields = Bitvector[cls._o].deserialize(stream, num_prefix_bytes) scope = scope - num_prefix_bytes @@ -488,13 +499,22 @@ def serialize(self, stream: BinaryIO) -> int: assert oindex == self.__class__._o temp_dyn_stream = io.BytesIO() - data = super().get_backing().get_left() - active_fields = self.active_fields() + if issubclass(self.__class__.B, StableContainer): + data = super().get_backing().get_left() + active_fields = self.active_fields() + n = self.__class__.B.N + else: + data = super().get_backing() + n = len(self.__class__.B.fields()) for fkey, (ftyp, _) in self.__class__.fields().items(): - (findex, _, _) = self.__class__.S._field_indices[fkey] - if not active_fields.get(findex): - continue - fnode = data.getter(2**get_depth(self.__class__.N) + findex) + if issubclass(self.__class__.B, StableContainer): + (findex, _, _) = self.__class__.B._field_indices[fkey] + if not active_fields.get(findex): + continue + fnode = data.getter(2**get_depth(n) + findex) + else: + findex = self.__class__.B._field_indices[fkey] + fnode = data.getter(2**get_depth(n) + findex) v = ftyp.view_from_backing(fnode) if ftyp.is_fixed_byte_length(): v.serialize(stream) @@ -508,16 +528,16 @@ def serialize(self, stream: BinaryIO) -> int: class OneOf(ComplexView): - def __class_getitem__(cls, s) -> Type["OneOf"]: - if not issubclass(s, StableContainer) and not issubclass(s, Container): - raise Exception(f"invalid oneof container: {s}") + def __class_getitem__(cls, b) -> Type["OneOf"]: + if not issubclass(b, StableContainer) and not issubclass(b, Container): + raise Exception(f"invalid OneOf base: {b}") - class OneOfView(OneOf, s): - S = s + class OneOfView(OneOf, b): + B = b @classmethod def fields(cls): - return s.fields() + return b.fields() OneOfView.__name__ = OneOfView.type_repr() return OneOfView @@ -527,19 +547,19 @@ def __repr__(self): @classmethod def type_repr(cls) -> str: - return f"OneOf[{cls.S}]" + return f"OneOf[{cls.B}]" @classmethod - def decode_bytes(cls: Type[S], bytez: bytes, *args, **kwargs) -> S: + def decode_bytes(cls: Type[B], bytez: bytes, *args, **kwargs) -> B: stream = io.BytesIO() stream.write(bytez) stream.seek(0) return cls.deserialize(stream, len(bytez), *args, **kwargs) @classmethod - def deserialize(cls: Type[S], stream: BinaryIO, scope: int, *args, **kwargs) -> S: - value = cls.S.deserialize(stream, scope) - v = cls.select_variant(value, *args, **kwargs) - if not issubclass(v.S, cls.S): - raise Exception(f"unsupported select_variant result: {v}") + def deserialize(cls: Type[B], stream: BinaryIO, scope: int, *args, **kwargs) -> B: + value = cls.B.deserialize(stream, scope) + v = cls.select_from_base(value, *args, **kwargs) + if not issubclass(v.B, cls.B): + raise Exception(f"unsupported select_from_base result: {v}") return v(backing=value.get_backing()) diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index 01fc50b..fba3c53 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -9,7 +9,7 @@ from remerkleable.bitfields import Bitvector, Bitlist from remerkleable.byte_arrays import ByteVector, ByteList from remerkleable.core import View, ObjType -from remerkleable.stable_container import OneOf, StableContainer, Variant +from remerkleable.stable_container import OneOf, StableContainer, MerkleizeAs from remerkleable.union import Union from hashlib import sha256 @@ -486,18 +486,18 @@ class Shape(StableContainer[4]): radius: Optional[uint16] # Inherits merkleization format from `Shape`, but is serialized more compactly - class Square(Variant[Shape]): + class Square(MerkleizeAs[Shape]): side: uint16 color: uint8 # Inherits merkleization format from `Shape`, but is serialized more compactly - class Circle(Variant[Shape]): + class Circle(MerkleizeAs[Shape]): radius: uint16 color: uint8 class AnyShape(OneOf[Shape]): @classmethod - def select_variant(cls, value: Shape, circle_allowed = False) -> Type[Shape]: + def select_from_base(cls, value: Shape, circle_allowed = False) -> Type[Shape]: if value.radius is not None: assert circle_allowed return Circle @@ -505,6 +505,19 @@ def select_variant(cls, value: Shape, circle_allowed = False) -> Type[Shape]: return Square assert False + # Compounds + class ShapePair(Container): + shape_1: Shape + shape_2: Shape + + class SquarePair(MerkleizeAs[ShapePair]): + shape_1: Square + shape_2: Square + + class CirclePair(MerkleizeAs[ShapePair]): + shape_2: Circle + shape_1: Circle + # Helper containers for merkleization testing class ShapePayload(Container): side: uint16 @@ -514,6 +527,22 @@ class ShapeRepr(Container): value: ShapePayload active_fields: Bitvector[4] + class ShapePairRepr(Container): + shape_1: ShapeRepr + shape_2: ShapeRepr + + class AnyShapePair(OneOf[ShapePair]): + @classmethod + def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[ShapePair]: + typ_1 = AnyShape.select_from_base(value.shape_1, circle_allowed) + typ_2 = AnyShape.select_from_base(value.shape_2, circle_allowed) + assert typ_1 == typ_2 + if typ_1 is Circle: + return CirclePair + if typ_1 is Square: + return SquarePair + assert False + # Square tests square_bytes_stable = bytes.fromhex("03420001") square_bytes_variant = bytes.fromhex("420001") @@ -641,6 +670,81 @@ class ShapeRepr(Container): except: pass + # SquarePair tests + square_pair_bytes_stable = bytes.fromhex("080000000c0000000342000103690001") + square_pair_bytes_variant = bytes.fromhex("420001690001") + square_pair_root = ShapePairRepr( + shape_1=ShapeRepr( + value=ShapePayload(side=0x42, color=1, radius=0), + active_fields=Bitvector[4](True, True, False, False), + ), + shape_2=ShapeRepr( + value=ShapePayload(side=0x69, color=1, radius=0), + active_fields=Bitvector[4](True, True, False, False), + ) + ).hash_tree_root() + shape_pairs = [ShapePair( + shape_1=Shape(side=0x42, color=1, radius=None), + shape_2=Shape(side=0x69, color=1, radius=None), + )] + square_pairs = [SquarePair( + shape_1=Square(side=0x42, color=1), + shape_2=Square(side=0x69, color=1), + )] + square_pairs.extend(list(SquarePair(backing=pair.get_backing()) for pair in shape_pairs)) + shape_pairs.extend(list(ShapePair(backing=pair.get_backing()) for pair in shape_pairs)) + shape_pairs.extend(list(ShapePair(backing=pair.get_backing()) for pair in square_pairs)) + square_pairs.extend(list(SquarePair(backing=pair.get_backing()) for pair in square_pairs)) + assert len(set(shape_pairs)) == 1 + assert len(set(square_pairs)) == 1 + assert all(pair.encode_bytes() == square_pair_bytes_stable for pair in shape_pairs) + assert all(pair.encode_bytes() == square_pair_bytes_variant for pair in square_pairs) + assert ( + SquarePair(backing=ShapePair.decode_bytes(square_pair_bytes_stable).get_backing()) == + SquarePair.decode_bytes(square_pair_bytes_variant) == + AnyShapePair.decode_bytes(square_pair_bytes_stable) == + AnyShapePair.decode_bytes(square_pair_bytes_stable, circle_allowed = True) + ) + assert all(pair.hash_tree_root() == square_pair_root for pair in shape_pairs) + assert all(pair.hash_tree_root() == square_pair_root for pair in square_pairs) + + # CirclePair tests + circle_pair_bytes_stable = bytes.fromhex("080000000c0000000601420006016900") + circle_pair_bytes_variant = bytes.fromhex("690001420001") + circle_pair_root = ShapePairRepr( + shape_1=ShapeRepr( + value=ShapePayload(side=0, color=1, radius=0x42), + active_fields=Bitvector[4](False, True, True, False), + ), + shape_2=ShapeRepr( + value=ShapePayload(side=0, color=1, radius=0x69), + active_fields=Bitvector[4](False, True, True, False), + ) + ).hash_tree_root() + shape_pairs = [ShapePair( + shape_1=Shape(side=None, color=1, radius=0x42), + shape_2=Shape(side=None, color=1, radius=0x69), + )] + circle_pairs = [CirclePair( + shape_1=Circle(radius=0x42, color=1), + shape_2=Circle(radius=0x69, color=1), + )] + circle_pairs.extend(list(CirclePair(backing=pair.get_backing()) for pair in shape_pairs)) + shape_pairs.extend(list(ShapePair(backing=pair.get_backing()) for pair in shape_pairs)) + shape_pairs.extend(list(ShapePair(backing=pair.get_backing()) for pair in circle_pairs)) + circle_pairs.extend(list(CirclePair(backing=pair.get_backing()) for pair in circle_pairs)) + assert len(set(shape_pairs)) == 1 + assert len(set(circle_pairs)) == 1 + assert all(pair.encode_bytes() == circle_pair_bytes_stable for pair in shape_pairs) + assert all(pair.encode_bytes() == circle_pair_bytes_variant for pair in circle_pairs) + assert ( + CirclePair(backing=ShapePair.decode_bytes(circle_pair_bytes_stable).get_backing()) == + CirclePair.decode_bytes(circle_pair_bytes_variant) == + AnyShapePair.decode_bytes(circle_pair_bytes_stable, circle_allowed = True) + ) + assert all(pair.hash_tree_root() == circle_pair_root for pair in shape_pairs) + assert all(pair.hash_tree_root() == circle_pair_root for pair in circle_pairs) + # Unsupported tests shape = Shape(side=None, color=1, radius=None) shape_bytes = bytes.fromhex("0201") From 499056ca8428c2834c619b55b38f2034ff64a938 Mon Sep 17 00:00:00 2001 From: Cayman Date: Tue, 14 May 2024 10:28:38 -0400 Subject: [PATCH 07/23] chore: add more StableContainer tests --- remerkleable/test_impl.py | 91 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index fba3c53..d72779e 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -838,3 +838,94 @@ class ShapeContainerRepr(Container): active_fields=Bitvector[4](False, True, True, False), ), ).hash_tree_root() + + +# basic container +class Shape1(StableContainer[4]): + side: Optional[uint16] + color: uint8 + radius: Optional[uint16] + +# basic container with different depth +class Shape2(StableContainer[8]): + side: Optional[uint16] + color: uint8 + radius: Optional[uint16] + +# basic container with variable fields +class Shape3(StableContainer[8]): + side: Optional[uint16] + colors: Optional[List[uint8, 4]] + radius: Optional[uint16] + +stable_container_tests = [ + { + 'value': Shape1(side=0x42, color=1, radius=0x42), + 'serialized': '074200014200', + 'hash_tree_root': '37b28eab19bc3e246e55d2e2b2027479454c27ee006d92d4847c84893a162e6d' + }, + { + 'value': Shape1(side=0x42, color=1, radius=None), + 'serialized': '03420001', + 'hash_tree_root': 'bfdb6fda9d02805e640c0f5767b8d1bb9ff4211498a5e2d7c0f36e1b88ce57ff' + }, + { + 'value': Shape1(side=None, color=1, radius=None), + 'serialized': '0201', + 'hash_tree_root': '522edd7309c0041b8eb6a218d756af558e9cf4c816441ec7e6eef42dfa47bb98' + }, + { + 'value': Shape1(side=None, color=1, radius=0x42), + 'serialized': '06014200', + 'hash_tree_root': 'f66d2c38c8d2afbd409e86c529dff728e9a4208215ca20ee44e49c3d11e145d8' + }, + { + 'value': Shape2(side=0x42, color=1, radius=0x42), + 'serialized': '074200014200', + 'hash_tree_root': '0792fb509377ee2ff3b953dd9a88eee11ac7566a8df41c6c67a85bc0b53efa4e' + }, + { + 'value': Shape2(side=0x42, color=1, radius=None), + 'serialized': '03420001', + 'hash_tree_root': 'ddc7acd38ae9d6d6788c14bd7635aeb1d7694768d7e00e1795bb6d328ec14f28' + }, + { + 'value': Shape2(side=None, color=1, radius=None), + 'serialized': '0201', + 'hash_tree_root': '9893ecf9b68030ff23c667a5f2e4a76538a8e2ab48fd060a524888a66fb938c9' + }, + { + 'value': Shape2(side=None, color=1, radius=0x42), + 'serialized': '06014200', + 'hash_tree_root': 'e823471310312d52aa1135d971a3ed72ba041ade3ec5b5077c17a39d73ab17c5' + }, + { + 'value': Shape3(side=0x42, colors=[1, 2], radius=0x42), + 'serialized': '0742000800000042000102', + 'hash_tree_root': '1093b0f1d88b1b2b458196fa860e0df7a7dc1837fe804b95d664279635cb302f' + }, + { + 'value': Shape3(side=0x42, colors=None, radius=None), + 'serialized': '014200', + 'hash_tree_root': '28df3f1c3eebd92504401b155c5cfe2f01c0604889e46ed3d22a3091dde1371f' + }, + { + 'value': Shape3(side=None, colors=[1, 2], radius=None), + 'serialized': '02040000000102', + 'hash_tree_root': '659638368467b2c052ca698fcb65902e9b42ce8e94e1f794dd5296ceac2dec3e' + }, + { + 'value': Shape3(side=None, colors=None, radius=0x42), + 'serialized': '044200', + 'hash_tree_root': 'd585dd0561c718bf4c29e4c1bd7d4efd4a5fe3c45942a7f778acb78fd0b2a4d2' + }, + { + 'value': Shape3(side=None, colors=[1, 2], radius=0x42), + 'serialized': '060600000042000102', + 'hash_tree_root': '00fc0cecc200a415a07372d5d5b8bc7ce49f52504ed3da0336f80a26d811c7bf' + } +] + +for test in stable_container_tests: + assert test['value'].encode_bytes().hex() == test['serialized'] + assert test['value'].hash_tree_root().hex() == test['hash_tree_root'] From 7ab57b9b8900677f5e439a44b60d4cda8c6f27a5 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Tue, 14 May 2024 21:19:14 +0300 Subject: [PATCH 08/23] Cleanup --- remerkleable/test_impl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index d72779e..0428ad9 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -479,7 +479,7 @@ class B(Container): def test_stable_container(): - # Defines the common merkleization format and a portable serialization format across variants + # Defines the common merkleization format and a portable serialization format class Shape(StableContainer[4]): side: Optional[uint16] color: uint8 @@ -505,15 +505,17 @@ def select_from_base(cls, value: Shape, circle_allowed = False) -> Type[Shape]: return Square assert False - # Compounds + # Defines a container with immutable scheme that contains two `StableContainer` class ShapePair(Container): shape_1: Shape shape_2: Shape + # Inherits merkleization format from `ShapePair`, and serializes more compactly class SquarePair(MerkleizeAs[ShapePair]): shape_1: Square shape_2: Square + # Inherits merkleization format from `ShapePair`, and reorders fields class CirclePair(MerkleizeAs[ShapePair]): shape_2: Circle shape_1: Circle From b1b0a75d447622b3e7ba57b62c81aabc6f28e354 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Tue, 14 May 2024 23:19:00 +0300 Subject: [PATCH 09/23] Fix indentation of new tests --- remerkleable/test_impl.py | 175 +++++++++++++++++++------------------- 1 file changed, 87 insertions(+), 88 deletions(-) diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index 0428ad9..dea8102 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -841,93 +841,92 @@ class ShapeContainerRepr(Container): ), ).hash_tree_root() + # basic container + class Shape1(StableContainer[4]): + side: Optional[uint16] + color: uint8 + radius: Optional[uint16] -# basic container -class Shape1(StableContainer[4]): - side: Optional[uint16] - color: uint8 - radius: Optional[uint16] - -# basic container with different depth -class Shape2(StableContainer[8]): - side: Optional[uint16] - color: uint8 - radius: Optional[uint16] - -# basic container with variable fields -class Shape3(StableContainer[8]): - side: Optional[uint16] - colors: Optional[List[uint8, 4]] - radius: Optional[uint16] - -stable_container_tests = [ - { - 'value': Shape1(side=0x42, color=1, radius=0x42), - 'serialized': '074200014200', - 'hash_tree_root': '37b28eab19bc3e246e55d2e2b2027479454c27ee006d92d4847c84893a162e6d' - }, - { - 'value': Shape1(side=0x42, color=1, radius=None), - 'serialized': '03420001', - 'hash_tree_root': 'bfdb6fda9d02805e640c0f5767b8d1bb9ff4211498a5e2d7c0f36e1b88ce57ff' - }, - { - 'value': Shape1(side=None, color=1, radius=None), - 'serialized': '0201', - 'hash_tree_root': '522edd7309c0041b8eb6a218d756af558e9cf4c816441ec7e6eef42dfa47bb98' - }, - { - 'value': Shape1(side=None, color=1, radius=0x42), - 'serialized': '06014200', - 'hash_tree_root': 'f66d2c38c8d2afbd409e86c529dff728e9a4208215ca20ee44e49c3d11e145d8' - }, - { - 'value': Shape2(side=0x42, color=1, radius=0x42), - 'serialized': '074200014200', - 'hash_tree_root': '0792fb509377ee2ff3b953dd9a88eee11ac7566a8df41c6c67a85bc0b53efa4e' - }, - { - 'value': Shape2(side=0x42, color=1, radius=None), - 'serialized': '03420001', - 'hash_tree_root': 'ddc7acd38ae9d6d6788c14bd7635aeb1d7694768d7e00e1795bb6d328ec14f28' - }, - { - 'value': Shape2(side=None, color=1, radius=None), - 'serialized': '0201', - 'hash_tree_root': '9893ecf9b68030ff23c667a5f2e4a76538a8e2ab48fd060a524888a66fb938c9' - }, - { - 'value': Shape2(side=None, color=1, radius=0x42), - 'serialized': '06014200', - 'hash_tree_root': 'e823471310312d52aa1135d971a3ed72ba041ade3ec5b5077c17a39d73ab17c5' - }, - { - 'value': Shape3(side=0x42, colors=[1, 2], radius=0x42), - 'serialized': '0742000800000042000102', - 'hash_tree_root': '1093b0f1d88b1b2b458196fa860e0df7a7dc1837fe804b95d664279635cb302f' - }, - { - 'value': Shape3(side=0x42, colors=None, radius=None), - 'serialized': '014200', - 'hash_tree_root': '28df3f1c3eebd92504401b155c5cfe2f01c0604889e46ed3d22a3091dde1371f' - }, - { - 'value': Shape3(side=None, colors=[1, 2], radius=None), - 'serialized': '02040000000102', - 'hash_tree_root': '659638368467b2c052ca698fcb65902e9b42ce8e94e1f794dd5296ceac2dec3e' - }, - { - 'value': Shape3(side=None, colors=None, radius=0x42), - 'serialized': '044200', - 'hash_tree_root': 'd585dd0561c718bf4c29e4c1bd7d4efd4a5fe3c45942a7f778acb78fd0b2a4d2' - }, - { - 'value': Shape3(side=None, colors=[1, 2], radius=0x42), - 'serialized': '060600000042000102', - 'hash_tree_root': '00fc0cecc200a415a07372d5d5b8bc7ce49f52504ed3da0336f80a26d811c7bf' - } -] + # basic container with different depth + class Shape2(StableContainer[8]): + side: Optional[uint16] + color: uint8 + radius: Optional[uint16] + + # basic container with variable fields + class Shape3(StableContainer[8]): + side: Optional[uint16] + colors: Optional[List[uint8, 4]] + radius: Optional[uint16] -for test in stable_container_tests: - assert test['value'].encode_bytes().hex() == test['serialized'] - assert test['value'].hash_tree_root().hex() == test['hash_tree_root'] + stable_container_tests = [ + { + 'value': Shape1(side=0x42, color=1, radius=0x42), + 'serialized': '074200014200', + 'hash_tree_root': '37b28eab19bc3e246e55d2e2b2027479454c27ee006d92d4847c84893a162e6d' + }, + { + 'value': Shape1(side=0x42, color=1, radius=None), + 'serialized': '03420001', + 'hash_tree_root': 'bfdb6fda9d02805e640c0f5767b8d1bb9ff4211498a5e2d7c0f36e1b88ce57ff' + }, + { + 'value': Shape1(side=None, color=1, radius=None), + 'serialized': '0201', + 'hash_tree_root': '522edd7309c0041b8eb6a218d756af558e9cf4c816441ec7e6eef42dfa47bb98' + }, + { + 'value': Shape1(side=None, color=1, radius=0x42), + 'serialized': '06014200', + 'hash_tree_root': 'f66d2c38c8d2afbd409e86c529dff728e9a4208215ca20ee44e49c3d11e145d8' + }, + { + 'value': Shape2(side=0x42, color=1, radius=0x42), + 'serialized': '074200014200', + 'hash_tree_root': '0792fb509377ee2ff3b953dd9a88eee11ac7566a8df41c6c67a85bc0b53efa4e' + }, + { + 'value': Shape2(side=0x42, color=1, radius=None), + 'serialized': '03420001', + 'hash_tree_root': 'ddc7acd38ae9d6d6788c14bd7635aeb1d7694768d7e00e1795bb6d328ec14f28' + }, + { + 'value': Shape2(side=None, color=1, radius=None), + 'serialized': '0201', + 'hash_tree_root': '9893ecf9b68030ff23c667a5f2e4a76538a8e2ab48fd060a524888a66fb938c9' + }, + { + 'value': Shape2(side=None, color=1, radius=0x42), + 'serialized': '06014200', + 'hash_tree_root': 'e823471310312d52aa1135d971a3ed72ba041ade3ec5b5077c17a39d73ab17c5' + }, + { + 'value': Shape3(side=0x42, colors=[1, 2], radius=0x42), + 'serialized': '0742000800000042000102', + 'hash_tree_root': '1093b0f1d88b1b2b458196fa860e0df7a7dc1837fe804b95d664279635cb302f' + }, + { + 'value': Shape3(side=0x42, colors=None, radius=None), + 'serialized': '014200', + 'hash_tree_root': '28df3f1c3eebd92504401b155c5cfe2f01c0604889e46ed3d22a3091dde1371f' + }, + { + 'value': Shape3(side=None, colors=[1, 2], radius=None), + 'serialized': '02040000000102', + 'hash_tree_root': '659638368467b2c052ca698fcb65902e9b42ce8e94e1f794dd5296ceac2dec3e' + }, + { + 'value': Shape3(side=None, colors=None, radius=0x42), + 'serialized': '044200', + 'hash_tree_root': 'd585dd0561c718bf4c29e4c1bd7d4efd4a5fe3c45942a7f778acb78fd0b2a4d2' + }, + { + 'value': Shape3(side=None, colors=[1, 2], radius=0x42), + 'serialized': '060600000042000102', + 'hash_tree_root': '00fc0cecc200a415a07372d5d5b8bc7ce49f52504ed3da0336f80a26d811c7bf' + } + ] + + for test in stable_container_tests: + assert test['value'].encode_bytes().hex() == test['serialized'] + assert test['value'].hash_tree_root().hex() == test['hash_tree_root'] From ba2c8bb83295acc1964578c71826d2b8ef170f94 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Tue, 14 May 2024 23:39:47 +0300 Subject: [PATCH 10/23] Update variable names for 'merkleizeas' --- remerkleable/test_impl.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index dea8102..372e0ab 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -547,7 +547,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap # Square tests square_bytes_stable = bytes.fromhex("03420001") - square_bytes_variant = bytes.fromhex("420001") + square_bytes_merkleizeas = bytes.fromhex("420001") square_root = ShapeRepr( value=ShapePayload(side=0x42, color=1, radius=0), active_fields=Bitvector[4](True, True, False, False), @@ -561,10 +561,10 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert len(set(shapes)) == 1 assert len(set(squares)) == 1 assert all(shape.encode_bytes() == square_bytes_stable for shape in shapes) - assert all(square.encode_bytes() == square_bytes_variant for square in squares) + assert all(square.encode_bytes() == square_bytes_merkleizeas for square in squares) assert ( Square(backing=Shape.decode_bytes(square_bytes_stable).get_backing()) == - Square.decode_bytes(square_bytes_variant) == + Square.decode_bytes(square_bytes_merkleizeas) == AnyShape.decode_bytes(square_bytes_stable) == AnyShape.decode_bytes(square_bytes_stable, circle_allowed = True) ) @@ -592,7 +592,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap for square in squares: square.side = 0x1337 square_bytes_stable = bytes.fromhex("03371301") - square_bytes_variant = bytes.fromhex("371301") + square_bytes_merkleizeas = bytes.fromhex("371301") square_root = ShapeRepr( value=ShapePayload(side=0x1337, color=1, radius=0), active_fields=Bitvector[4](True, True, False, False), @@ -600,10 +600,10 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert len(set(shapes)) == 1 assert len(set(squares)) == 1 assert all(shape.encode_bytes() == square_bytes_stable for shape in shapes) - assert all(square.encode_bytes() == square_bytes_variant for square in squares) + assert all(square.encode_bytes() == square_bytes_merkleizeas for square in squares) assert ( Square(backing=Shape.decode_bytes(square_bytes_stable).get_backing()) == - Square.decode_bytes(square_bytes_variant) == + Square.decode_bytes(square_bytes_merkleizeas) == AnyShape.decode_bytes(square_bytes_stable) == AnyShape.decode_bytes(square_bytes_stable, circle_allowed = True) ) @@ -624,7 +624,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap # Circle tests circle_bytes_stable = bytes.fromhex("06014200") - circle_bytes_variant = bytes.fromhex("420001") + circle_bytes_merkleizeas = bytes.fromhex("420001") circle_root = ShapeRepr( value=ShapePayload(side=0, color=1, radius=0x42), active_fields=Bitvector[4](False, True, True, False), @@ -641,10 +641,10 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert len(set(shapes)) == 1 assert len(set(circles)) == 1 assert all(shape.encode_bytes() == circle_bytes_stable for shape in shapes) - assert all(circle.encode_bytes() == circle_bytes_variant for circle in circles) + assert all(circle.encode_bytes() == circle_bytes_merkleizeas for circle in circles) assert ( Circle(backing=Shape.decode_bytes(circle_bytes_stable).get_backing()) == - Circle.decode_bytes(circle_bytes_variant) == + Circle.decode_bytes(circle_bytes_merkleizeas) == AnyShape.decode_bytes(circle_bytes_stable, circle_allowed = True) ) assert all(shape.hash_tree_root() == circle_root for shape in shapes) @@ -674,7 +674,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap # SquarePair tests square_pair_bytes_stable = bytes.fromhex("080000000c0000000342000103690001") - square_pair_bytes_variant = bytes.fromhex("420001690001") + square_pair_bytes_merkleizeas = bytes.fromhex("420001690001") square_pair_root = ShapePairRepr( shape_1=ShapeRepr( value=ShapePayload(side=0x42, color=1, radius=0), @@ -700,10 +700,10 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert len(set(shape_pairs)) == 1 assert len(set(square_pairs)) == 1 assert all(pair.encode_bytes() == square_pair_bytes_stable for pair in shape_pairs) - assert all(pair.encode_bytes() == square_pair_bytes_variant for pair in square_pairs) + assert all(pair.encode_bytes() == square_pair_bytes_merkleizeas for pair in square_pairs) assert ( SquarePair(backing=ShapePair.decode_bytes(square_pair_bytes_stable).get_backing()) == - SquarePair.decode_bytes(square_pair_bytes_variant) == + SquarePair.decode_bytes(square_pair_bytes_merkleizeas) == AnyShapePair.decode_bytes(square_pair_bytes_stable) == AnyShapePair.decode_bytes(square_pair_bytes_stable, circle_allowed = True) ) @@ -712,7 +712,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap # CirclePair tests circle_pair_bytes_stable = bytes.fromhex("080000000c0000000601420006016900") - circle_pair_bytes_variant = bytes.fromhex("690001420001") + circle_pair_bytes_merkleizeas = bytes.fromhex("690001420001") circle_pair_root = ShapePairRepr( shape_1=ShapeRepr( value=ShapePayload(side=0, color=1, radius=0x42), @@ -738,10 +738,10 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert len(set(shape_pairs)) == 1 assert len(set(circle_pairs)) == 1 assert all(pair.encode_bytes() == circle_pair_bytes_stable for pair in shape_pairs) - assert all(pair.encode_bytes() == circle_pair_bytes_variant for pair in circle_pairs) + assert all(pair.encode_bytes() == circle_pair_bytes_merkleizeas for pair in circle_pairs) assert ( CirclePair(backing=ShapePair.decode_bytes(circle_pair_bytes_stable).get_backing()) == - CirclePair.decode_bytes(circle_pair_bytes_variant) == + CirclePair.decode_bytes(circle_pair_bytes_merkleizeas) == AnyShapePair.decode_bytes(circle_pair_bytes_stable, circle_allowed = True) ) assert all(pair.hash_tree_root() == circle_pair_root for pair in shape_pairs) From 53455dd2b1957325f779b34f3b84e5ff8944a6dc Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Thu, 16 May 2024 16:14:13 +0300 Subject: [PATCH 11/23] Rename `MerkleizeAs` > `Profile` --- remerkleable/stable_container.py | 18 +++++++------- remerkleable/test_impl.py | 40 ++++++++++++++++---------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index c5405de..2ce5ebe 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -262,7 +262,7 @@ def serialize(self, stream: BinaryIO) -> int: return num_prefix_bytes + num_data_bytes -class MerkleizeAs(ComplexView): +class Profile(ComplexView): _o: int def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None, **kwargs): @@ -293,15 +293,15 @@ def __init_subclass__(cls, *args, **kwargs): cls._o += 1 assert cls._o == 0 or issubclass(cls.B, StableContainer) - def __class_getitem__(cls, b) -> Type["MerkleizeAs"]: + def __class_getitem__(cls, b) -> Type["Profile"]: if not issubclass(b, StableContainer) and not issubclass(b, Container): - raise Exception(f"invalid MerkleizeAs base: {b}") + raise Exception(f"invalid Profile base: {b}") - class MerkleizeAsView(MerkleizeAs, b): + class ProfileView(Profile, b): B = b - MerkleizeAsView.__name__ = MerkleizeAsView.type_repr() - return MerkleizeAsView + ProfileView.__name__ = ProfileView.type_repr() + return ProfileView @classmethod def fields(cls) -> Dict[str, Tuple[Type[View], bool]]: @@ -321,7 +321,7 @@ def type_byte_length(cls) -> int: if cls.is_fixed_byte_length(): return cls.min_byte_length() else: - raise Exception("dynamic length MerkleizeAs does not have a fixed byte length") + raise Exception("dynamic length Profile does not have a fixed byte length") @classmethod def min_byte_length(cls) -> int: @@ -427,14 +427,14 @@ def __repr__(self): @classmethod def type_repr(cls) -> str: - return f"MerkleizeAs[{cls.B.__name__}]" + return f"Profile[{cls.B.__name__}]" @classmethod def deserialize(cls: Type[B], stream: BinaryIO, scope: int) -> B: if cls._o > 0: num_prefix_bytes = Bitvector[cls._o].type_byte_length() if scope < num_prefix_bytes: - raise ValueError("scope too small, cannot read MerkleizeAs optional fields") + raise ValueError("scope too small, cannot read Profile optional fields") optional_fields = Bitvector[cls._o].deserialize(stream, num_prefix_bytes) scope = scope - num_prefix_bytes diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index 372e0ab..67a0670 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -9,7 +9,7 @@ from remerkleable.bitfields import Bitvector, Bitlist from remerkleable.byte_arrays import ByteVector, ByteList from remerkleable.core import View, ObjType -from remerkleable.stable_container import OneOf, StableContainer, MerkleizeAs +from remerkleable.stable_container import OneOf, Profile, StableContainer from remerkleable.union import Union from hashlib import sha256 @@ -486,12 +486,12 @@ class Shape(StableContainer[4]): radius: Optional[uint16] # Inherits merkleization format from `Shape`, but is serialized more compactly - class Square(MerkleizeAs[Shape]): + class Square(Profile[Shape]): side: uint16 color: uint8 # Inherits merkleization format from `Shape`, but is serialized more compactly - class Circle(MerkleizeAs[Shape]): + class Circle(Profile[Shape]): radius: uint16 color: uint8 @@ -511,12 +511,12 @@ class ShapePair(Container): shape_2: Shape # Inherits merkleization format from `ShapePair`, and serializes more compactly - class SquarePair(MerkleizeAs[ShapePair]): + class SquarePair(Profile[ShapePair]): shape_1: Square shape_2: Square # Inherits merkleization format from `ShapePair`, and reorders fields - class CirclePair(MerkleizeAs[ShapePair]): + class CirclePair(Profile[ShapePair]): shape_2: Circle shape_1: Circle @@ -547,7 +547,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap # Square tests square_bytes_stable = bytes.fromhex("03420001") - square_bytes_merkleizeas = bytes.fromhex("420001") + square_bytes_profile = bytes.fromhex("420001") square_root = ShapeRepr( value=ShapePayload(side=0x42, color=1, radius=0), active_fields=Bitvector[4](True, True, False, False), @@ -561,10 +561,10 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert len(set(shapes)) == 1 assert len(set(squares)) == 1 assert all(shape.encode_bytes() == square_bytes_stable for shape in shapes) - assert all(square.encode_bytes() == square_bytes_merkleizeas for square in squares) + assert all(square.encode_bytes() == square_bytes_profile for square in squares) assert ( Square(backing=Shape.decode_bytes(square_bytes_stable).get_backing()) == - Square.decode_bytes(square_bytes_merkleizeas) == + Square.decode_bytes(square_bytes_profile) == AnyShape.decode_bytes(square_bytes_stable) == AnyShape.decode_bytes(square_bytes_stable, circle_allowed = True) ) @@ -592,7 +592,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap for square in squares: square.side = 0x1337 square_bytes_stable = bytes.fromhex("03371301") - square_bytes_merkleizeas = bytes.fromhex("371301") + square_bytes_profile = bytes.fromhex("371301") square_root = ShapeRepr( value=ShapePayload(side=0x1337, color=1, radius=0), active_fields=Bitvector[4](True, True, False, False), @@ -600,10 +600,10 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert len(set(shapes)) == 1 assert len(set(squares)) == 1 assert all(shape.encode_bytes() == square_bytes_stable for shape in shapes) - assert all(square.encode_bytes() == square_bytes_merkleizeas for square in squares) + assert all(square.encode_bytes() == square_bytes_profile for square in squares) assert ( Square(backing=Shape.decode_bytes(square_bytes_stable).get_backing()) == - Square.decode_bytes(square_bytes_merkleizeas) == + Square.decode_bytes(square_bytes_profile) == AnyShape.decode_bytes(square_bytes_stable) == AnyShape.decode_bytes(square_bytes_stable, circle_allowed = True) ) @@ -624,7 +624,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap # Circle tests circle_bytes_stable = bytes.fromhex("06014200") - circle_bytes_merkleizeas = bytes.fromhex("420001") + circle_bytes_profile = bytes.fromhex("420001") circle_root = ShapeRepr( value=ShapePayload(side=0, color=1, radius=0x42), active_fields=Bitvector[4](False, True, True, False), @@ -641,10 +641,10 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert len(set(shapes)) == 1 assert len(set(circles)) == 1 assert all(shape.encode_bytes() == circle_bytes_stable for shape in shapes) - assert all(circle.encode_bytes() == circle_bytes_merkleizeas for circle in circles) + assert all(circle.encode_bytes() == circle_bytes_profile for circle in circles) assert ( Circle(backing=Shape.decode_bytes(circle_bytes_stable).get_backing()) == - Circle.decode_bytes(circle_bytes_merkleizeas) == + Circle.decode_bytes(circle_bytes_profile) == AnyShape.decode_bytes(circle_bytes_stable, circle_allowed = True) ) assert all(shape.hash_tree_root() == circle_root for shape in shapes) @@ -674,7 +674,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap # SquarePair tests square_pair_bytes_stable = bytes.fromhex("080000000c0000000342000103690001") - square_pair_bytes_merkleizeas = bytes.fromhex("420001690001") + square_pair_bytes_profile = bytes.fromhex("420001690001") square_pair_root = ShapePairRepr( shape_1=ShapeRepr( value=ShapePayload(side=0x42, color=1, radius=0), @@ -700,10 +700,10 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert len(set(shape_pairs)) == 1 assert len(set(square_pairs)) == 1 assert all(pair.encode_bytes() == square_pair_bytes_stable for pair in shape_pairs) - assert all(pair.encode_bytes() == square_pair_bytes_merkleizeas for pair in square_pairs) + assert all(pair.encode_bytes() == square_pair_bytes_profile for pair in square_pairs) assert ( SquarePair(backing=ShapePair.decode_bytes(square_pair_bytes_stable).get_backing()) == - SquarePair.decode_bytes(square_pair_bytes_merkleizeas) == + SquarePair.decode_bytes(square_pair_bytes_profile) == AnyShapePair.decode_bytes(square_pair_bytes_stable) == AnyShapePair.decode_bytes(square_pair_bytes_stable, circle_allowed = True) ) @@ -712,7 +712,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap # CirclePair tests circle_pair_bytes_stable = bytes.fromhex("080000000c0000000601420006016900") - circle_pair_bytes_merkleizeas = bytes.fromhex("690001420001") + circle_pair_bytes_profile = bytes.fromhex("690001420001") circle_pair_root = ShapePairRepr( shape_1=ShapeRepr( value=ShapePayload(side=0, color=1, radius=0x42), @@ -738,10 +738,10 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert len(set(shape_pairs)) == 1 assert len(set(circle_pairs)) == 1 assert all(pair.encode_bytes() == circle_pair_bytes_stable for pair in shape_pairs) - assert all(pair.encode_bytes() == circle_pair_bytes_merkleizeas for pair in circle_pairs) + assert all(pair.encode_bytes() == circle_pair_bytes_profile for pair in circle_pairs) assert ( CirclePair(backing=ShapePair.decode_bytes(circle_pair_bytes_stable).get_backing()) == - CirclePair.decode_bytes(circle_pair_bytes_merkleizeas) == + CirclePair.decode_bytes(circle_pair_bytes_profile) == AnyShapePair.decode_bytes(circle_pair_bytes_stable, circle_allowed = True) ) assert all(pair.hash_tree_root() == circle_pair_root for pair in shape_pairs) From 0c481d5ffee22799c9c27c7993e74fb4e7d4cc83 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Thu, 16 May 2024 16:18:08 +0300 Subject: [PATCH 12/23] Remove tests that reorder fields as the feature got removed --- remerkleable/test_impl.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index 67a0670..5600acf 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -492,8 +492,8 @@ class Square(Profile[Shape]): # Inherits merkleization format from `Shape`, but is serialized more compactly class Circle(Profile[Shape]): - radius: uint16 color: uint8 + radius: uint16 class AnyShape(OneOf[Shape]): @classmethod @@ -515,10 +515,10 @@ class SquarePair(Profile[ShapePair]): shape_1: Square shape_2: Square - # Inherits merkleization format from `ShapePair`, and reorders fields + # Inherits merkleization format from `ShapePair`, and serializes more compactly class CirclePair(Profile[ShapePair]): - shape_2: Circle shape_1: Circle + shape_2: Circle # Helper containers for merkleization testing class ShapePayload(Container): @@ -624,7 +624,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap # Circle tests circle_bytes_stable = bytes.fromhex("06014200") - circle_bytes_profile = bytes.fromhex("420001") + circle_bytes_profile = bytes.fromhex("014200") circle_root = ShapeRepr( value=ShapePayload(side=0, color=1, radius=0x42), active_fields=Bitvector[4](False, True, True, False), @@ -712,7 +712,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap # CirclePair tests circle_pair_bytes_stable = bytes.fromhex("080000000c0000000601420006016900") - circle_pair_bytes_profile = bytes.fromhex("690001420001") + circle_pair_bytes_profile = bytes.fromhex("014200016900") circle_pair_root = ShapePairRepr( shape_1=ShapeRepr( value=ShapePayload(side=0, color=1, radius=0x42), @@ -823,7 +823,7 @@ class ShapeContainerRepr(Container): square=Square(side=0x42, color=1), circle=Circle(radius=0x42, color=1), ) - container_bytes = bytes.fromhex("0a000000420001420001074200014200") + container_bytes = bytes.fromhex("0a000000420001014200074200014200") assert container.encode_bytes() == container_bytes assert ShapeContainer.decode_bytes(container_bytes) == container assert container.hash_tree_root() == ShapeContainerRepr( From c933754450cc669da0c0d17c7fae07735fdaca81 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Tue, 21 May 2024 14:34:43 +0200 Subject: [PATCH 13/23] Add path support (`gindex` lookup by key) --- remerkleable/stable_container.py | 40 +++++++++++++++++++++-- remerkleable/test_typing.py | 55 ++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 3 deletions(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 2ce5ebe..a19e61c 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -2,15 +2,16 @@ # The EIP is still under review, functionality may change or go away without deprecation. import io -from typing import BinaryIO, Dict, List as PyList, Optional, Tuple, TypeVar, Type, Union as PyUnion, \ +from typing import Any, BinaryIO, Dict, List as PyList, Optional, Tuple, TypeVar, Type, Union as PyUnion, \ get_args, get_origin from textwrap import indent from remerkleable.bitfields import Bitvector from remerkleable.complex import ComplexView, Container, FieldOffset, \ decode_offset, encode_offset from remerkleable.core import View, ViewHook, OFFSET_BYTE_LENGTH -from remerkleable.tree import NavigationError, Node, PairNode, \ - get_depth, subtree_fill_to_contents, zero_node +from remerkleable.tree import Gindex, NavigationError, Node, PairNode, \ + get_depth, subtree_fill_to_contents, zero_node, \ + RIGHT_GINDEX N = TypeVar('N') B = TypeVar('B', bound="ComplexView") @@ -261,6 +262,22 @@ def serialize(self, stream: BinaryIO) -> int: return num_prefix_bytes + num_data_bytes + @classmethod + def navigate_type(cls, key: Any) -> Type[View]: + if key == '__active_fields__': + return Bitvector[cls.N] + (_, ftyp, fopt) = cls._field_indices[key] + if fopt: + return Optional[ftyp] + return ftyp + + @classmethod + def key_to_static_gindex(cls, key: Any) -> Gindex: + if key == '__active_fields__': + return RIGHT_GINDEX + (findex, _, _) = cls._field_indices[key] + return 2**get_depth(cls.N) * 2 + findex + class Profile(ComplexView): _o: int @@ -526,6 +543,23 @@ def serialize(self, stream: BinaryIO) -> int: return num_prefix_bytes + num_data_bytes + @classmethod + def navigate_type(cls, key: Any) -> Type[View]: + if key == '__active_fields__': + return Bitvector[cls.B.N] + (ftyp, fopt) = cls.fields()[key] + if fopt: + return Optional[ftyp] + return ftyp + + @classmethod + def key_to_static_gindex(cls, key: Any) -> Gindex: + if key == '__active_fields__': + return RIGHT_GINDEX + (_, _) = cls.fields()[key] + (findex, _, _) = cls.B._field_indices[key] + return 2**get_depth(cls.N) * 2 + findex + class OneOf(ComplexView): def __class_getitem__(cls, b) -> Type["OneOf"]: diff --git a/remerkleable/test_typing.py b/remerkleable/test_typing.py index 6862414..0094349 100644 --- a/remerkleable/test_typing.py +++ b/remerkleable/test_typing.py @@ -2,6 +2,7 @@ import pytest # type: ignore +from typing import Optional from random import Random from remerkleable.complex import Container, Vector, List @@ -10,6 +11,7 @@ from remerkleable.bitfields import Bitvector, Bitlist from remerkleable.byte_arrays import ByteVector, Bytes1, Bytes4, Bytes8, Bytes32, Bytes48, Bytes96 from remerkleable.core import BasicView, View +from remerkleable.stable_container import Profile, StableContainer from remerkleable.union import Union from remerkleable.tree import get_depth, merkle_hash, LEFT_GINDEX, RIGHT_GINDEX @@ -362,6 +364,59 @@ class Wrapper(Container): except KeyError: pass + class StableFields(StableContainer[8]): + foo: Optional[uint32] + bar: Optional[uint64] + quix: Optional[uint64] + more: Optional[uint32] + + class FooFields(Profile[StableFields]): + foo: uint32 + more: Optional[uint32] + + class BarFields(Profile[StableFields]): + bar: uint64 + quix: uint64 + more: Optional[uint32] + + assert issubclass((StableFields / '__active_fields__').navigate_type(), Bitvector) + assert (StableFields / '__active_fields__').navigate_type().vector_length() == 8 + assert (StableFields / '__active_fields__').gindex() == 0b11 + assert (StableFields / 'foo').navigate_type() == Optional[uint32] + assert (StableFields / 'foo').gindex() == 0b10000 + assert (StableFields / 'bar').navigate_type() == Optional[uint64] + assert (StableFields / 'bar').gindex() == 0b10001 + assert (StableFields / 'quix').navigate_type() == Optional[uint64] + assert (StableFields / 'quix').gindex() == 0b10010 + assert (StableFields / 'more').navigate_type() == Optional[uint32] + assert (StableFields / 'more').gindex() == 0b10011 + + assert issubclass((FooFields / '__active_fields__').navigate_type(), Bitvector) + assert (FooFields / '__active_fields__').navigate_type().vector_length() == 8 + assert (FooFields / '__active_fields__').gindex() == 0b11 + assert (FooFields / 'foo').navigate_type() == uint32 + assert (FooFields / 'foo').gindex() == 0b10000 + assert (FooFields / 'more').navigate_type() == Optional[uint32] + assert (FooFields / 'more').gindex() == 0b10011 + try: + (FooFields / 'bar').navigate_type() + assert False + except KeyError: + pass + + assert issubclass((BarFields / '__active_fields__').navigate_type(), Bitvector) + assert (BarFields / '__active_fields__').navigate_type().vector_length() == 8 + assert (BarFields / '__active_fields__').gindex() == 0b11 + assert (BarFields / 'bar').navigate_type() == uint64 + assert (BarFields / 'bar').gindex() == 0b10001 + assert (BarFields / 'more').navigate_type() == Optional[uint32] + assert (BarFields / 'more').gindex() == 0b10011 + try: + (BarFields / 'foo').navigate_type() + assert False + except KeyError: + pass + def test_bitvector(): for size in [1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 511, 512, 513, 1023, 1024, 1025]: From ff2b81d407d3544ef11a3570b19b26e48b75ded0 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Thu, 23 May 2024 14:17:02 +0200 Subject: [PATCH 14/23] Allow fields of name `N` in `StableContainer` and `B` in `Profile` --- remerkleable/stable_container.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index a19e61c..65055e0 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -137,6 +137,11 @@ def active_fields(self) -> Bitvector: active_fields_node = super().get_backing().get_right() return Bitvector[self.__class__.N].view_from_backing(active_fields_node) + def __getattribute__(self, item): + if item == 'N': + raise AttributeError(f"use .__class__.{item} to access {item}") + return object.__getattribute__(self, item) + def __getattr__(self, item): if item[0] == '_': return super().__getattribute__(item) @@ -378,6 +383,11 @@ def optional_fields(self) -> Bitvector: oindex += 1 return optional_fields + def __getattribute__(self, item): + if item == 'B': + raise AttributeError(f"use .__class__.{item} to access {item}") + return object.__getattribute__(self, item) + def __getattr__(self, item): if item[0] == '_': return super().__getattribute__(item) From 12ed1983527dde9ad6355e22b5635ee626aed692 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Thu, 23 May 2024 21:09:21 +0200 Subject: [PATCH 15/23] Fix `Profile` should not inherit from `StableContainer` --- remerkleable/stable_container.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 65055e0..d3523e3 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -319,7 +319,7 @@ def __class_getitem__(cls, b) -> Type["Profile"]: if not issubclass(b, StableContainer) and not issubclass(b, Container): raise Exception(f"invalid Profile base: {b}") - class ProfileView(Profile, b): + class ProfileView(Profile): B = b ProfileView.__name__ = ProfileView.type_repr() @@ -567,8 +567,13 @@ def key_to_static_gindex(cls, key: Any) -> Gindex: if key == '__active_fields__': return RIGHT_GINDEX (_, _) = cls.fields()[key] - (findex, _, _) = cls.B._field_indices[key] - return 2**get_depth(cls.N) * 2 + findex + if issubclass(cls.B, StableContainer): + (findex, _, _) = cls.B._field_indices[key] + return 2**get_depth(cls.B.N) * 2 + findex + else: + findex = cls.B._field_indices[key] + n = len(cls.B.fields()) + return 2**get_depth(n) + findex class OneOf(ComplexView): From ee17bb095998598f97bf239809f5b4f8c2df81f5 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Sat, 1 Jun 2024 23:19:49 +0200 Subject: [PATCH 16/23] Require fields to be `Optional` in `StableContainer` --- remerkleable/stable_container.py | 191 +++++++++++++++++-------------- remerkleable/test_impl.py | 6 +- 2 files changed, 106 insertions(+), 91 deletions(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index d3523e3..32054eb 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -2,20 +2,21 @@ # The EIP is still under review, functionality may change or go away without deprecation. import io -from typing import Any, BinaryIO, Dict, List as PyList, Optional, Tuple, TypeVar, Type, Union as PyUnion, \ +from typing import Any, BinaryIO, Dict, List as PyList, Optional, Tuple, \ + TypeVar, Type, Union as PyUnion, \ get_args, get_origin from textwrap import indent from remerkleable.bitfields import Bitvector from remerkleable.complex import ComplexView, Container, FieldOffset, \ decode_offset, encode_offset -from remerkleable.core import View, ViewHook, OFFSET_BYTE_LENGTH +from remerkleable.core import View, ViewHook, ViewMeta, OFFSET_BYTE_LENGTH from remerkleable.tree import Gindex, NavigationError, Node, PairNode, \ get_depth, subtree_fill_to_contents, zero_node, \ RIGHT_GINDEX -N = TypeVar('N') -B = TypeVar('B', bound="ComplexView") -S = TypeVar('S', bound="ComplexView") +N = TypeVar('N', bound=int) +B = TypeVar('B', bound='ComplexView') +S = TypeVar('S', bound='ComplexView') def all_fields(cls) -> Dict[str, Tuple[Type[View], bool]]: @@ -38,37 +39,31 @@ def field_val_repr(self, fkey: str, ftyp: Type[View], fopt: bool) -> str: field_repr = field_repr[:i+1] + indent(field_repr[i+1:], ' ' * len(field_start)) return field_start + field_repr except NavigationError: - return f"{field_start} *omitted*" + return f'{field_start} *omitted*' def repr(self) -> str: - return f"{self.__class__.type_repr()}:\n" + '\n'.join( + return f'{self.__class__.type_repr()}:\n' + '\n'.join( indent(field_val_repr(self, fkey, ftyp, fopt), ' ') for fkey, (ftyp, fopt) in self.__class__.fields().items()) class StableContainer(ComplexView): - _field_indices: Dict[str, Tuple[int, Type[View], bool]] - __slots__ = '_field_indices' + __slots__ = '_field_indices', 'N' + _field_indices: Dict[str, Tuple[int, Type[View]]] + N: int def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None, **kwargs): if backing is not None: if len(kwargs) != 0: - raise Exception("cannot have both a backing and elements to init fields") + raise Exception('Cannot have both a backing and elements to init fields') return super().__new__(cls, backing=backing, hook=hook, **kwargs) - for fkey, (ftyp, fopt) in cls.fields().items(): - if fkey not in kwargs: - if not fopt: - raise AttributeError(f"Field '{fkey}' is required in {cls}") - kwargs[fkey] = None - input_nodes = [] active_fields = Bitvector[cls.N]() - for findex, (fkey, (ftyp, fopt)) in enumerate(cls.fields().items()): + for fkey, (findex, ftyp) in cls._field_indices.items(): fnode: Node - assert fkey in kwargs - finput = kwargs.pop(fkey) + finput = kwargs.pop(fkey) if fkey in kwargs else None if finput is None: fnode = zero_node(0) active_fields.set(findex, False) @@ -79,35 +74,61 @@ def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None fnode = ftyp.coerce_view(finput).get_backing() active_fields.set(findex, True) input_nodes.append(fnode) - if len(kwargs) > 0: - raise AttributeError(f'The field names [{"".join(kwargs.keys())}] are not defined in {cls}') + raise AttributeError(f'Fields [{''.join(kwargs.keys())}] unknown in `{cls.__name__}`') backing = PairNode( left=subtree_fill_to_contents(input_nodes, get_depth(cls.N)), - right=active_fields.get_backing()) + right=active_fields.get_backing(), + ) return super().__new__(cls, backing=backing, hook=hook, **kwargs) - def __init_subclass__(cls, *args, **kwargs): - super().__init_subclass__(*args, **kwargs) - cls._field_indices = { - fkey: (i, ftyp, fopt) - for i, (fkey, (ftyp, fopt)) in enumerate(cls.fields().items()) - } - - def __class_getitem__(cls, n) -> Type["StableContainer"]: + def __init_subclass__(cls, **kwargs): + if 'n' not in kwargs: + raise TypeError(f'Missing capacity: `{cls.__name__}(StableContainer)`') + n = kwargs.pop('n') + if not isinstance(n, int): + raise TypeError(f'Invalid capacity: `{cls.__name__}(StableContainer[{n}])`') if n <= 0: - raise Exception(f"invalid stablecontainer capacity: {n}") - - class StableContainerView(StableContainer): - N = n + raise TypeError(f'Unsupported capacity: `{cls.__name__}(StableContainer[{n}])`') + cls.N = n + + def __class_getitem__(cls, n: int) -> Type['StableContainer']: + class StableContainerMeta(ViewMeta): + def __new__(cls, name, bases, dct): + return super().__new__(cls, name, bases, dct, n=n) + + class StableContainerView(StableContainer, metaclass=StableContainerMeta): + def __init_subclass__(cls, *args, **kwargs): + if 'N' in cls.__dict__: + raise TypeError(f'Cannot override `N` inside `{cls.__name__}`') + cls._field_indices = {} + for findex, (fkey, v) in enumerate(cls.__annotations__.items()): + if fkey in cls._field_indices: + raise TypeError(f'Field `{fkey}` defined multiple times in `{cls.__name}`') + if ( + get_origin(v) != PyUnion + or len(get_args(v)) != 2 + or type(None) not in get_args(v) + ): + raise TypeError( + f'`StableContainer` fields must be `Optional[T]` ' + f'but `{cls.__name__}.{fkey}` has type `{v.__name__}`' + ) + ftyp = get_args(v)[0] if get_args(v)[0] is not type(None) else get_args(v)[1] + cls._field_indices[fkey] = (findex, ftyp) + if len(cls._field_indices) > cls.N: + raise TypeError( + f'`{cls.__name__}` is `StableContainer[{cls.N}]` ' + f'but contains {len(cls._field_indices)} fields' + ) StableContainerView.__name__ = StableContainerView.type_repr() return StableContainerView @classmethod - def fields(cls) -> Dict[str, Tuple[Type[View], bool]]: - return all_fields(cls) + def fields(cls) -> Dict[str, Type[View]]: + { fkey: ftyp for fkey, (_, ftyp) in cls._field_indices } @classmethod def is_fixed_byte_length(cls) -> bool: @@ -115,19 +136,12 @@ def is_fixed_byte_length(cls) -> bool: @classmethod def min_byte_length(cls) -> int: - total = Bitvector[cls.N].type_byte_length() - for _, (ftyp, fopt) in cls.fields().items(): - if fopt: - continue - if not ftyp.is_fixed_byte_length(): - total += OFFSET_BYTE_LENGTH - total += ftyp.min_byte_length() - return total + return Bitvector[cls.N].type_byte_length() @classmethod def max_byte_length(cls) -> int: total = Bitvector[cls.N].type_byte_length() - for _, (ftyp, _) in cls.fields().items(): + for (_, ftyp) in cls._field_indices.values(): if not ftyp.is_fixed_byte_length(): total += OFFSET_BYTE_LENGTH total += ftyp.max_byte_length() @@ -139,7 +153,7 @@ def active_fields(self) -> Bitvector: def __getattribute__(self, item): if item == 'N': - raise AttributeError(f"use .__class__.{item} to access {item}") + raise AttributeError(f'Use `.__class__.{item}` to access `{item}`') return object.__getattribute__(self, item) def __getattr__(self, item): @@ -147,14 +161,12 @@ def __getattr__(self, item): return super().__getattribute__(item) else: try: - (findex, ftyp, fopt) = self.__class__._field_indices[item] + (findex, ftyp) = self.__class__._field_indices[item] except KeyError: - raise AttributeError(f"unknown attribute {item}") + raise AttributeError(f'Unknown field `{item}`') if not self.active_fields().get(findex): - assert fopt return None - data = super().get_backing().get_left() fnode = data.getter(2**get_depth(self.__class__.N) + findex) return ftyp.view_from_backing(fnode) @@ -164,13 +176,12 @@ def __setattr__(self, key, value): super().__setattr__(key, value) else: try: - (findex, ftyp, fopt) = self.__class__._field_indices[key] + (findex, ftyp) = self.__class__._field_indices[key] except KeyError: - raise AttributeError(f"unknown attribute {key}") + raise AttributeError(f'Unknown field `{key}`') next_backing = self.get_backing() - assert value is not None or fopt active_fields = self.active_fields() active_fields.set(findex, value is not None) next_backing = next_backing.rebind_right(active_fields.get_backing()) @@ -193,24 +204,25 @@ def __repr__(self): @classmethod def type_repr(cls) -> str: - return f"StableContainer[{cls.N}]" + return f'StableContainer[{cls.N}]' @classmethod def deserialize(cls: Type[S], stream: BinaryIO, scope: int) -> S: num_prefix_bytes = Bitvector[cls.N].type_byte_length() if scope < num_prefix_bytes: - raise ValueError("scope too small, cannot read StableContainer active fields") + raise ValueError(f'Scope too small, cannot read `StableContainer[{cls.N}]` active fields') active_fields = Bitvector[cls.N].deserialize(stream, num_prefix_bytes) scope = scope - num_prefix_bytes - max_findex = 0 - field_values: Dict[str, Optional[View]] = {} + for findex in range(len(cls._field_indices), cls.N): + if active_fields.get(findex): + raise Exception(f'Unknown field index {findex}') + + field_values: Dict[str, View] = {} dyn_fields: PyList[FieldOffset] = [] fixed_size = 0 - for findex, (fkey, (ftyp, _)) in enumerate(cls.fields().items()): - max_findex = findex + for fkey, (findex, ftyp) in cls._field_indices.items(): if not active_fields.get(findex): - field_values[fkey] = None continue if ftyp.is_fixed_byte_length(): fsize = ftyp.type_byte_length() @@ -222,37 +234,41 @@ def deserialize(cls: Type[S], stream: BinaryIO, scope: int) -> S: fixed_size += OFFSET_BYTE_LENGTH if len(dyn_fields) > 0: if dyn_fields[0].offset < fixed_size: - raise Exception(f"first offset {dyn_fields[0].offset} is " - f"smaller than expected fixed size {fixed_size}") + raise Exception(f'First offset {dyn_fields[0].offset} is ' + f'smaller than expected fixed size {fixed_size}') for i, (fkey, ftyp, foffset) in enumerate(dyn_fields): next_offset = dyn_fields[i + 1].offset if i + 1 < len(dyn_fields) else scope if foffset > next_offset: - raise Exception(f"offset {i} is invalid: {foffset} " - f"larger than next offset {next_offset}") + raise Exception(f'Offset {i} is invalid: {foffset} ' + f'larger than next offset {next_offset}') fsize = next_offset - foffset f_min_size, f_max_size = ftyp.min_byte_length(), ftyp.max_byte_length() if not (f_min_size <= fsize <= f_max_size): - raise Exception(f"offset {i} is invalid, size out of bounds: " - f"{foffset}, next {next_offset}, implied size: {fsize}, " - f"size bounds: [{f_min_size}, {f_max_size}]") + raise Exception(f'Offset {i} is invalid, size out of bounds: ' + f'{foffset}, next {next_offset}, implied size: {fsize}, ' + f'size bounds: [{f_min_size}, {f_max_size}]') field_values[fkey] = ftyp.deserialize(stream, fsize) - for findex in range(max_findex + 1, cls.N): - if active_fields.get(findex): - raise Exception(f"unknown field index {findex}") return cls(**field_values) # type: ignore def serialize(self, stream: BinaryIO) -> int: active_fields = self.active_fields() num_prefix_bytes = active_fields.serialize(stream) - num_data_bytes = sum( - ftyp.type_byte_length() if ftyp.is_fixed_byte_length() else OFFSET_BYTE_LENGTH - for findex, (_, (ftyp, _)) in enumerate(self.__class__.fields().items()) - if active_fields.get(findex)) + num_data_bytes = 0 + has_dyn_fields = False + for (findex, ftyp) in self.__class__._field_indices.values(): + if not active_fields.get(findex): + continue + if ftyp.is_fixed_byte_length(): + num_data_bytes += ftyp.type_byte_length() + else: + num_data_bytes += OFFSET_BYTE_LENGTH + has_dyn_fields = True - temp_dyn_stream = io.BytesIO() + if has_dyn_fields: + temp_dyn_stream = io.BytesIO() data = super().get_backing().get_left() - for findex, (_, (ftyp, _)) in enumerate(self.__class__.fields().items()): + for (findex, ftyp) in self.__class__._field_indices.values(): if not active_fields.get(findex): continue fnode = data.getter(2**get_depth(self.__class__.N) + findex) @@ -262,8 +278,9 @@ def serialize(self, stream: BinaryIO) -> int: else: encode_offset(stream, num_data_bytes) num_data_bytes += v.serialize(temp_dyn_stream) # type: ignore - temp_dyn_stream.seek(0) - stream.write(temp_dyn_stream.read(num_data_bytes)) + if has_dyn_fields: + temp_dyn_stream.seek(0) + stream.write(temp_dyn_stream.read()) return num_prefix_bytes + num_data_bytes @@ -271,16 +288,14 @@ def serialize(self, stream: BinaryIO) -> int: def navigate_type(cls, key: Any) -> Type[View]: if key == '__active_fields__': return Bitvector[cls.N] - (_, ftyp, fopt) = cls._field_indices[key] - if fopt: - return Optional[ftyp] - return ftyp + (_, ftyp) = cls._field_indices[key] + return Optional[ftyp] @classmethod def key_to_static_gindex(cls, key: Any) -> Gindex: if key == '__active_fields__': return RIGHT_GINDEX - (findex, _, _) = cls._field_indices[key] + (findex, _) = cls._field_indices[key] return 2**get_depth(cls.N) * 2 + findex @@ -378,7 +393,7 @@ def optional_fields(self) -> Bitvector: oindex = 0 for fkey, (_, fopt) in self.__class__.fields().items(): if fopt: - (findex, _, _) = self.__class__.B._field_indices[fkey] + (findex, _) = self.__class__.B._field_indices[fkey] optional_fields.set(oindex, active_fields.get(findex)) oindex += 1 return optional_fields @@ -397,7 +412,7 @@ def __getattr__(self, item): except KeyError: raise AttributeError(f"unknown attribute {item}") try: - (findex, _, _) = self.__class__.B._field_indices[item] + (findex, _) = self.__class__.B._field_indices[item] except KeyError: raise AttributeError(f"unknown attribute {item} in base") @@ -421,7 +436,7 @@ def __setattr__(self, key, value): except KeyError: raise AttributeError(f"unknown attribute {key}") try: - (findex, _, _) = self.__class__.B._field_indices[key] + (findex, _) = self.__class__.B._field_indices[key] except KeyError: raise AttributeError(f"unknown attribute {key} in base") @@ -535,7 +550,7 @@ def serialize(self, stream: BinaryIO) -> int: n = len(self.__class__.B.fields()) for fkey, (ftyp, _) in self.__class__.fields().items(): if issubclass(self.__class__.B, StableContainer): - (findex, _, _) = self.__class__.B._field_indices[fkey] + (findex, _) = self.__class__.B._field_indices[fkey] if not active_fields.get(findex): continue fnode = data.getter(2**get_depth(n) + findex) @@ -568,7 +583,7 @@ def key_to_static_gindex(cls, key: Any) -> Gindex: return RIGHT_GINDEX (_, _) = cls.fields()[key] if issubclass(cls.B, StableContainer): - (findex, _, _) = cls.B._field_indices[key] + (findex, _) = cls.B._field_indices[key] return 2**get_depth(cls.B.N) * 2 + findex else: findex = cls.B._field_indices[key] diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index 5600acf..d7dbb95 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -482,7 +482,7 @@ def test_stable_container(): # Defines the common merkleization format and a portable serialization format class Shape(StableContainer[4]): side: Optional[uint16] - color: uint8 + color: Optional[uint8] radius: Optional[uint16] # Inherits merkleization format from `Shape`, but is serialized more compactly @@ -844,13 +844,13 @@ class ShapeContainerRepr(Container): # basic container class Shape1(StableContainer[4]): side: Optional[uint16] - color: uint8 + color: Optional[uint8] radius: Optional[uint16] # basic container with different depth class Shape2(StableContainer[8]): side: Optional[uint16] - color: uint8 + color: Optional[uint8] radius: Optional[uint16] # basic container with variable fields From 1d3bb9ca03dd68b07e93b8e75e6fbb0897ddf123 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Sat, 1 Jun 2024 23:22:29 +0200 Subject: [PATCH 17/23] Fix lint --- remerkleable/stable_container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 32054eb..6aa2893 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -75,7 +75,7 @@ def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None active_fields.set(findex, True) input_nodes.append(fnode) if len(kwargs) > 0: - raise AttributeError(f'Fields [{''.join(kwargs.keys())}] unknown in `{cls.__name__}`') + raise AttributeError(f'Fields [{"".join(kwargs.keys())}] unknown in `{cls.__name__}`') backing = PairNode( left=subtree_fill_to_contents(input_nodes, get_depth(cls.N)), From c3e33d8100dd9802e1cf5e63afafbbb95136a946 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Sat, 1 Jun 2024 23:27:48 +0200 Subject: [PATCH 18/23] Fix `fields` accessor --- remerkleable/stable_container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 6aa2893..a0ceaf6 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -128,7 +128,7 @@ def __init_subclass__(cls, *args, **kwargs): @classmethod def fields(cls) -> Dict[str, Type[View]]: - { fkey: ftyp for fkey, (_, ftyp) in cls._field_indices } + { fkey: ftyp for fkey, (_, ftyp) in cls._field_indices.items() } @classmethod def is_fixed_byte_length(cls) -> bool: From 56338cf697df457f4ab883b79bdcf96f9049fac7 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Sat, 1 Jun 2024 23:36:03 +0200 Subject: [PATCH 19/23] Add missing `return` --- remerkleable/stable_container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index a0ceaf6..2fa1bd3 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -128,7 +128,7 @@ def __init_subclass__(cls, *args, **kwargs): @classmethod def fields(cls) -> Dict[str, Type[View]]: - { fkey: ftyp for fkey, (_, ftyp) in cls._field_indices.items() } + return { fkey: ftyp for fkey, (_, ftyp) in cls._field_indices.items() } @classmethod def is_fixed_byte_length(cls) -> bool: From 22db96ba93aabdf57dd649c24f57f5e906d5b4d5 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Tue, 4 Jun 2024 12:14:33 +0200 Subject: [PATCH 20/23] Cleanup `Profile` implementation --- remerkleable/stable_container.py | 402 ++++++++++++++++--------------- remerkleable/test_impl.py | 62 +---- 2 files changed, 209 insertions(+), 255 deletions(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 2fa1bd3..8172ece 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -15,17 +15,37 @@ RIGHT_GINDEX N = TypeVar('N', bound=int) -B = TypeVar('B', bound='ComplexView') -S = TypeVar('S', bound='ComplexView') +SV = TypeVar('SV', bound='ComplexView') +BV = TypeVar('BV', bound='ComplexView') -def all_fields(cls) -> Dict[str, Tuple[Type[View], bool]]: - fields = {} - for k, v in cls.__annotations__.items(): - fopt = get_origin(v) == PyUnion and type(None) in get_args(v) - ftyp = get_args(v)[0] if fopt else v - fields[k] = (ftyp, fopt) - return fields +def stable_get(self, findex, ftyp, n): + if not self.active_fields().get(findex): + return None + data = self.get_backing().get_left() + fnode = data.getter(2**get_depth(n) + findex) + return ftyp.view_from_backing(fnode) + + +def stable_set(self, findex, ftyp, n, value): + next_backing = self.get_backing() + + active_fields = self.active_fields() + active_fields.set(findex, value is not None) + next_backing = next_backing.rebind_right(active_fields.get_backing()) + + if value is not None: + if isinstance(value, ftyp): + fnode = value.get_backing() + else: + fnode = ftyp.coerce_view(value).get_backing() + else: + fnode = zero_node(0) + data = next_backing.get_left() + next_data = data.setter(2**get_depth(n) + findex)(fnode) + next_backing = next_backing.rebind_left(next_data) + + self.set_backing(next_backing) def field_val_repr(self, fkey: str, ftyp: Type[View], fopt: bool) -> str: @@ -42,12 +62,6 @@ def field_val_repr(self, fkey: str, ftyp: Type[View], fopt: bool) -> str: return f'{field_start} *omitted*' -def repr(self) -> str: - return f'{self.__class__.type_repr()}:\n' + '\n'.join( - indent(field_val_repr(self, fkey, ftyp, fopt), ' ') - for fkey, (ftyp, fopt) in self.__class__.fields().items()) - - class StableContainer(ComplexView): __slots__ = '_field_indices', 'N' _field_indices: Dict[str, Tuple[int, Type[View]]] @@ -99,23 +113,21 @@ def __new__(cls, name, bases, dct): return super().__new__(cls, name, bases, dct, n=n) class StableContainerView(StableContainer, metaclass=StableContainerMeta): - def __init_subclass__(cls, *args, **kwargs): + def __init_subclass__(cls, **kwargs): if 'N' in cls.__dict__: raise TypeError(f'Cannot override `N` inside `{cls.__name__}`') cls._field_indices = {} - for findex, (fkey, v) in enumerate(cls.__annotations__.items()): - if fkey in cls._field_indices: - raise TypeError(f'Field `{fkey}` defined multiple times in `{cls.__name}`') + for findex, (fkey, t) in enumerate(cls.__annotations__.items()): if ( - get_origin(v) != PyUnion - or len(get_args(v)) != 2 - or type(None) not in get_args(v) + get_origin(t) != PyUnion + or len(get_args(t)) != 2 + or type(None) not in get_args(t) ): raise TypeError( f'`StableContainer` fields must be `Optional[T]` ' - f'but `{cls.__name__}.{fkey}` has type `{v.__name__}`' + f'but `{cls.__name__}.{fkey}` has type `{t.__name__}`' ) - ftyp = get_args(v)[0] if get_args(v)[0] is not type(None) else get_args(v)[1] + ftyp = get_args(t)[0] if get_args(t)[0] is not type(None) else get_args(t)[1] cls._field_indices[fkey] = (findex, ftyp) if len(cls._field_indices) > cls.N: raise TypeError( @@ -147,6 +159,18 @@ def max_byte_length(cls) -> int: total += ftyp.max_byte_length() return total + @classmethod + def is_packed(cls) -> bool: + return False + + @classmethod + def tree_depth(cls) -> int: + return get_depth(cls.N) + + @classmethod + def item_elem_cls(cls, i: int) -> Type[View]: + return list(cls._field_indices.values())[i] + def active_fields(self) -> Bitvector: active_fields_node = super().get_backing().get_right() return Bitvector[self.__class__.N].view_from_backing(active_fields_node) @@ -165,11 +189,7 @@ def __getattr__(self, item): except KeyError: raise AttributeError(f'Unknown field `{item}`') - if not self.active_fields().get(findex): - return None - data = super().get_backing().get_left() - fnode = data.getter(2**get_depth(self.__class__.N) + findex) - return ftyp.view_from_backing(fnode) + return stable_get(self, findex, ftyp, self.__class__.N) def __setattr__(self, key, value): if key[0] == '_': @@ -180,37 +200,22 @@ def __setattr__(self, key, value): except KeyError: raise AttributeError(f'Unknown field `{key}`') - next_backing = self.get_backing() - - active_fields = self.active_fields() - active_fields.set(findex, value is not None) - next_backing = next_backing.rebind_right(active_fields.get_backing()) - - if value is not None: - if isinstance(value, ftyp): - fnode = value.get_backing() - else: - fnode = ftyp.coerce_view(value).get_backing() - else: - fnode = zero_node(0) - data = next_backing.get_left() - next_data = data.setter(2**get_depth(self.__class__.N) + findex)(fnode) - next_backing = next_backing.rebind_left(next_data) - - self.set_backing(next_backing) + stable_set(self, findex, ftyp, self.__class__.N, value) def __repr__(self): - return repr(self) + return f'{self.__class__.type_repr()}:\n' + '\n'.join( + indent(field_val_repr(self, fkey, ftyp, fopt=True), ' ') + for fkey, (_, ftyp) in self.__class__._field_indices.items()) @classmethod def type_repr(cls) -> str: return f'StableContainer[{cls.N}]' @classmethod - def deserialize(cls: Type[S], stream: BinaryIO, scope: int) -> S: + def deserialize(cls: Type[SV], stream: BinaryIO, scope: int) -> SV: num_prefix_bytes = Bitvector[cls.N].type_byte_length() if scope < num_prefix_bytes: - raise ValueError(f'Scope too small, cannot read `StableContainer[{cls.N}]` active fields') + raise ValueError(f'Scope too small for `StableContainer[{cls.N}]` active fields') active_fields = Bitvector[cls.N].deserialize(stream, num_prefix_bytes) scope = scope - num_prefix_bytes @@ -300,55 +305,108 @@ def key_to_static_gindex(cls, key: Any) -> Gindex: class Profile(ComplexView): + __slots__ = '_field_indices', '_o', 'B' + _field_indices: Dict[str, Tuple[int, Type[View], bool]] _o: int + B: PyUnion[Type[StableContainer], Type[Container]] def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None, **kwargs): if backing is not None: if len(kwargs) != 0: - raise Exception("cannot have both a backing and elements to init fields") + raise Exception('Cannot have both a backing and elements to init fields') return super().__new__(cls, backing=backing, hook=hook, **kwargs) - extra_kwargs = kwargs.copy() - for fkey, (ftyp, fopt) in cls.fields().items(): - if fkey in extra_kwargs: - extra_kwargs.pop(fkey) + extra_kw = kwargs.copy() + for fkey, (_, _, fopt) in cls._field_indices.items(): + if fkey in extra_kw: + extra_kw.pop(fkey) elif not fopt: - raise AttributeError(f"Field '{fkey}' is required in {cls}") + raise AttributeError(f'Field `{fkey}` is required in {cls.__name__}') else: pass - if len(extra_kwargs) > 0: - raise AttributeError(f'The field names [{"".join(extra_kwargs.keys())}] are not defined in {cls}') + if len(extra_kw) > 0: + raise AttributeError(f'Fields [{"".join(extra_kw.keys())}] unknown in `{cls.__name__}`') value = cls.B(backing, hook, **kwargs) return cls(backing=value.get_backing()) - def __init_subclass__(cls, *args, **kwargs): - super().__init_subclass__(*args, **kwargs) - cls._o = 0 - for _, (_, fopt) in cls.fields().items(): - if fopt: - cls._o += 1 - assert cls._o == 0 or issubclass(cls.B, StableContainer) - - def __class_getitem__(cls, b) -> Type["Profile"]: + def __init_subclass__(cls, **kwargs): + if 'b' not in kwargs: + raise TypeError(f'Missing base type: `{cls.__name__}(Profile)`') + b = kwargs.pop('b') if not issubclass(b, StableContainer) and not issubclass(b, Container): - raise Exception(f"invalid Profile base: {b}") + raise TypeError(f'Invalid base type: `{cls.__name__}(Profile[{b.__name__}])`') + cls.B = b - class ProfileView(Profile): - B = b + def __class_getitem__(cls, b) -> Type['Profile']: + class ProfileMeta(ViewMeta): + def __new__(cls, name, bases, dct): + return super().__new__(cls, name, bases, dct, b=b) + + class ProfileView(Profile, metaclass=ProfileMeta): + def __init_subclass__(cls, **kwargs): + if 'B' in cls.__dict__: + raise TypeError(f'Cannot override `B` inside `{cls.__name__}`') + cls._field_indices = {} + cls._o = 0 + for (fkey, t) in cls.__annotations__.items(): + if fkey not in cls.B._field_indices: + raise TypeError( + f'`{cls.__name__}` fields must exist in the base type ' + f'but `{fkey}` is not defined in `{cls.B.__name__}`' + ) + if issubclass(cls.B, StableContainer): + (findex, ftyp) = cls.B._field_indices[fkey] + else: + findex = cls.B._field_indices[fkey] + ftyp = cls.B.fields()[fkey] + fopt = ( + get_origin(t) == PyUnion + and len(get_args(t)) == 2 + and type(None) in get_args(t) + ) + if fopt: + if not issubclass(cls.B, StableContainer): + raise TypeError( + f'`{cls.__name__}.{fkey}` cannot be `Optional[T]` ' + f'as base type `{cls.B.__name__}` is not a `StableContainer`' + ) + t = get_args(t)[0] if get_args(t)[0] is not type(None) else get_args(t)[1] + if t == ftyp: + pass + elif issubclass(t, Profile) and t.B == ftyp: + pass + else: + raise TypeError( + f'`{cls.__name__}.{fkey}` has type `{t.__name__}`, incompatible ' + f'with base field `{cls.B.__name__}.{fkey}` of type `{ftyp.__name__}`' + ) + cls._field_indices[fkey] = (findex, t, fopt) + if fopt: + cls._o += 1 + if ( + not issubclass(cls.B, StableContainer) + and len(cls._field_indices) != len(cls.B._field_indices) + ): + for fkey, (findex, ftyp) in cls.B._field_indices.items(): + if fkey not in cls._field_indices: + raise TypeError( + f'`{cls.__name__}.{fkey}` of type `{ftyp.__name__}` is required ' + f'as base type `{cls.B.__name__}` is not a `StbleContainer`' + ) ProfileView.__name__ = ProfileView.type_repr() return ProfileView @classmethod def fields(cls) -> Dict[str, Tuple[Type[View], bool]]: - return all_fields(cls) + return { fkey: (ftyp, fopt) for fkey, (_, ftyp, fopt) in cls._field_indices.items() } @classmethod def is_fixed_byte_length(cls) -> bool: if cls._o > 0: return False - for _, (ftyp, _) in cls.fields().items(): + for (_, ftyp, _) in cls._field_indices.values(): if not ftyp.is_fixed_byte_length(): return False return True @@ -358,12 +416,12 @@ def type_byte_length(cls) -> int: if cls.is_fixed_byte_length(): return cls.min_byte_length() else: - raise Exception("dynamic length Profile does not have a fixed byte length") + raise Exception(f'Dynamic length `Profile` does not have a fixed byte length') @classmethod def min_byte_length(cls) -> int: total = Bitvector[cls._o].type_byte_length() if cls._o > 0 else 0 - for _, (ftyp, fopt) in cls.fields().items(): + for (_, ftyp, fopt) in cls._field_indices.values(): if fopt: continue if not ftyp.is_fixed_byte_length(): @@ -374,33 +432,47 @@ def min_byte_length(cls) -> int: @classmethod def max_byte_length(cls) -> int: total = Bitvector[cls._o].type_byte_length() if cls._o > 0 else 0 - for _, (ftyp, _) in cls.fields().items(): + for (_, ftyp, _) in cls._field_indices.values(): if not ftyp.is_fixed_byte_length(): total += OFFSET_BYTE_LENGTH total += ftyp.max_byte_length() return total + @classmethod + def is_packed(cls) -> bool: + return False + + @classmethod + def tree_depth(cls) -> int: + return cls.B.tree_depth() + + @classmethod + def item_elem_cls(cls, i: int) -> Type[View]: + return cls.B.item_elem_cls(i) + def active_fields(self) -> Bitvector: - assert issubclass(self.__class__.B, StableContainer) + if not issubclass(self.__class__.B, StableContainer): + raise Exception(f'`active_fields` requires `Profile` with `StableContainer` base') active_fields_node = super().get_backing().get_right() return Bitvector[self.__class__.B.N].view_from_backing(active_fields_node) def optional_fields(self) -> Bitvector: - assert issubclass(self.__class__.B, StableContainer) - assert self.__class__._o > 0 + if not issubclass(self.__class__.B, StableContainer): + raise Exception(f'`optional_fields` requires `Profile` with `StableContainer` base') + if self.__class__._o == 0: + raise Exception(f'`{self.__class__.__name__}` does not have any `Optional[T]` fields') active_fields = self.active_fields() optional_fields = Bitvector[self.__class__._o]() oindex = 0 - for fkey, (_, fopt) in self.__class__.fields().items(): + for (findex, _, fopt) in self.__class__._field_indices.values(): if fopt: - (findex, _) = self.__class__.B._field_indices[fkey] optional_fields.set(oindex, active_fields.get(findex)) oindex += 1 return optional_fields def __getattribute__(self, item): if item == 'B': - raise AttributeError(f"use .__class__.{item} to access {item}") + raise AttributeError(f'Use `.__class__.{item}` to access `{item}`') return object.__getattribute__(self, item) def __getattr__(self, item): @@ -408,75 +480,51 @@ def __getattr__(self, item): return super().__getattribute__(item) else: try: - (ftyp, fopt) = self.__class__.fields()[item] + (findex, ftyp, fopt) = self.__class__._field_indices[item] except KeyError: - raise AttributeError(f"unknown attribute {item}") - try: - (findex, _) = self.__class__.B._field_indices[item] - except KeyError: - raise AttributeError(f"unknown attribute {item} in base") - - if not issubclass(self.__class__.B, StableContainer): - return super().get(findex) - - if not self.active_fields().get(findex): - assert fopt - return None + raise AttributeError(f'Unknown field `{item}`') - data = super().get_backing().get_left() - fnode = data.getter(2**get_depth(self.__class__.B.N) + findex) - return ftyp.view_from_backing(fnode) + if issubclass(self.__class__.B, StableContainer): + value = stable_get(self, findex, ftyp, self.__class__.B.N) + else: + value = super().get(findex) + if not isinstance(value, ftyp): + value = ftyp(backing=value.get_backing()) + assert value is not None or fopt + return value def __setattr__(self, key, value): if key[0] == '_': super().__setattr__(key, value) else: try: - (ftyp, fopt) = self.__class__.fields()[key] + (findex, ftyp, fopt) = self.__class__._field_indices[key] except KeyError: - raise AttributeError(f"unknown attribute {key}") - try: - (findex, _) = self.__class__.B._field_indices[key] - except KeyError: - raise AttributeError(f"unknown attribute {key} in base") - - if not issubclass(self.__class__.B, StableContainer): - super().set(findex, value) - return - - next_backing = self.get_backing() + raise AttributeError(f'Unknown field `{key}`') - assert value is not None or fopt - active_fields = self.active_fields() - active_fields.set(findex, value is not None) - next_backing = next_backing.rebind_right(active_fields.get_backing()) + if value is None and not fopt: + raise ValueError(f'Field `{key}` is required and cannot be set to `None`') - if value is not None: - if isinstance(value, ftyp): - fnode = value.get_backing() - else: - fnode = ftyp.coerce_view(value).get_backing() + if issubclass(self.__class__.B, StableContainer): + stable_set(self, findex, ftyp, self.__class__.B.N, value) else: - fnode = zero_node(0) - data = next_backing.get_left() - next_data = data.setter(2**get_depth(self.__class__.B.N) + findex)(fnode) - next_backing = next_backing.rebind_left(next_data) - - self.set_backing(next_backing) + super().set(findex, value) def __repr__(self): - return repr(self) + return f'{self.__class__.type_repr()}:\n' + '\n'.join( + indent(field_val_repr(self, fkey, ftyp, fopt), ' ') + for fkey, (_, ftyp, fopt) in self.__class__._field_indices.items()) @classmethod def type_repr(cls) -> str: - return f"Profile[{cls.B.__name__}]" + return f'Profile[{cls.B.__name__}]' @classmethod - def deserialize(cls: Type[B], stream: BinaryIO, scope: int) -> B: + def deserialize(cls: Type[BV], stream: BinaryIO, scope: int) -> BV: if cls._o > 0: num_prefix_bytes = Bitvector[cls._o].type_byte_length() if scope < num_prefix_bytes: - raise ValueError("scope too small, cannot read Profile optional fields") + raise ValueError(f'Scope too small for `Profile[{cls.B.__name__}]` optional fields') optional_fields = Bitvector[cls._o].deserialize(stream, num_prefix_bytes) scope = scope - num_prefix_bytes @@ -484,11 +532,11 @@ def deserialize(cls: Type[B], stream: BinaryIO, scope: int) -> B: dyn_fields: PyList[FieldOffset] = [] fixed_size = 0 oindex = 0 - for fkey, (ftyp, fopt) in cls.fields().items(): + for fkey, (_, ftyp, fopt) in cls._field_indices.items(): if fopt: - have_field = optional_fields.get(oindex) + has_field = optional_fields.get(oindex) oindex += 1 - if not have_field: + if not has_field: field_values[fkey] = None continue if ftyp.is_fixed_byte_length(): @@ -502,21 +550,20 @@ def deserialize(cls: Type[B], stream: BinaryIO, scope: int) -> B: assert oindex == cls._o if len(dyn_fields) > 0: if dyn_fields[0].offset < fixed_size: - raise Exception(f"first offset {dyn_fields[0].offset} is " - f"smaller than expected fixed size {fixed_size}") + raise Exception(f'First offset {dyn_fields[0].offset} is ' + f'smaller than expected fixed size {fixed_size}') for i, (fkey, ftyp, foffset) in enumerate(dyn_fields): next_offset = dyn_fields[i + 1].offset if i + 1 < len(dyn_fields) else scope if foffset > next_offset: - raise Exception(f"offset {i} is invalid: {foffset} " - f"larger than next offset {next_offset}") + raise Exception(f'Offset {i} is invalid: {foffset} ' + f'larger than next offset {next_offset}') fsize = next_offset - foffset f_min_size, f_max_size = ftyp.min_byte_length(), ftyp.max_byte_length() if not (f_min_size <= fsize <= f_max_size): - raise Exception(f"offset {i} is invalid, size out of bounds: " - f"{foffset}, next {next_offset}, implied size: {fsize}, " - f"size bounds: [{f_min_size}, {f_max_size}]") + raise Exception(f'Offset {i} is invalid, size out of bounds: ' + f'{foffset}, next {next_offset}, implied size: {fsize}, ' + f'size bounds: [{f_min_size}, {f_max_size}]') field_values[fkey] = ftyp.deserialize(stream, fsize) - return cls(**field_values) # type: ignore def serialize(self, stream: BinaryIO) -> int: @@ -527,20 +574,23 @@ def serialize(self, stream: BinaryIO) -> int: num_prefix_bytes = 0 num_data_bytes = 0 + has_dyn_fields = False oindex = 0 - for _, (ftyp, fopt) in self.__class__.fields().items(): + for (_, ftyp, fopt) in self.__class__._field_indices.values(): if fopt: - have_field = optional_fields.get(oindex) + has_field = optional_fields.get(oindex) oindex += 1 - if not have_field: + if not has_field: continue if ftyp.is_fixed_byte_length(): num_data_bytes += ftyp.type_byte_length() else: num_data_bytes += OFFSET_BYTE_LENGTH + has_dyn_fields = True assert oindex == self.__class__._o - temp_dyn_stream = io.BytesIO() + if has_dyn_fields: + temp_dyn_stream = io.BytesIO() if issubclass(self.__class__.B, StableContainer): data = super().get_backing().get_left() active_fields = self.active_fields() @@ -548,23 +598,19 @@ def serialize(self, stream: BinaryIO) -> int: else: data = super().get_backing() n = len(self.__class__.B.fields()) - for fkey, (ftyp, _) in self.__class__.fields().items(): - if issubclass(self.__class__.B, StableContainer): - (findex, _) = self.__class__.B._field_indices[fkey] - if not active_fields.get(findex): - continue - fnode = data.getter(2**get_depth(n) + findex) - else: - findex = self.__class__.B._field_indices[fkey] - fnode = data.getter(2**get_depth(n) + findex) + for (findex, ftyp, _) in self.__class__._field_indices.values(): + if issubclass(self.__class__.B, StableContainer) and not active_fields.get(findex): + continue + fnode = data.getter(2**get_depth(n) + findex) v = ftyp.view_from_backing(fnode) if ftyp.is_fixed_byte_length(): v.serialize(stream) else: encode_offset(stream, num_data_bytes) num_data_bytes += v.serialize(temp_dyn_stream) # type: ignore - temp_dyn_stream.seek(0) - stream.write(temp_dyn_stream.read(num_data_bytes)) + if has_dyn_fields: + temp_dyn_stream.seek(0) + stream.write(temp_dyn_stream.read(num_data_bytes)) return num_prefix_bytes + num_data_bytes @@ -572,58 +618,16 @@ def serialize(self, stream: BinaryIO) -> int: def navigate_type(cls, key: Any) -> Type[View]: if key == '__active_fields__': return Bitvector[cls.B.N] - (ftyp, fopt) = cls.fields()[key] - if fopt: - return Optional[ftyp] - return ftyp + (_, ftyp, fopt) = cls._field_indices[key] + return Optional[ftyp] if fopt else ftyp @classmethod def key_to_static_gindex(cls, key: Any) -> Gindex: if key == '__active_fields__': return RIGHT_GINDEX - (_, _) = cls.fields()[key] + (findex, _, _) = cls._field_indices[key] if issubclass(cls.B, StableContainer): - (findex, _) = cls.B._field_indices[key] return 2**get_depth(cls.B.N) * 2 + findex else: - findex = cls.B._field_indices[key] n = len(cls.B.fields()) return 2**get_depth(n) + findex - - -class OneOf(ComplexView): - def __class_getitem__(cls, b) -> Type["OneOf"]: - if not issubclass(b, StableContainer) and not issubclass(b, Container): - raise Exception(f"invalid OneOf base: {b}") - - class OneOfView(OneOf, b): - B = b - - @classmethod - def fields(cls): - return b.fields() - - OneOfView.__name__ = OneOfView.type_repr() - return OneOfView - - def __repr__(self): - return repr(self) - - @classmethod - def type_repr(cls) -> str: - return f"OneOf[{cls.B}]" - - @classmethod - def decode_bytes(cls: Type[B], bytez: bytes, *args, **kwargs) -> B: - stream = io.BytesIO() - stream.write(bytez) - stream.seek(0) - return cls.deserialize(stream, len(bytez), *args, **kwargs) - - @classmethod - def deserialize(cls: Type[B], stream: BinaryIO, scope: int, *args, **kwargs) -> B: - value = cls.B.deserialize(stream, scope) - v = cls.select_from_base(value, *args, **kwargs) - if not issubclass(v.B, cls.B): - raise Exception(f"unsupported select_from_base result: {v}") - return v(backing=value.get_backing()) diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index d7dbb95..61fe485 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -9,7 +9,7 @@ from remerkleable.bitfields import Bitvector, Bitlist from remerkleable.byte_arrays import ByteVector, ByteList from remerkleable.core import View, ObjType -from remerkleable.stable_container import OneOf, Profile, StableContainer +from remerkleable.stable_container import Profile, StableContainer from remerkleable.union import Union from hashlib import sha256 @@ -495,16 +495,6 @@ class Circle(Profile[Shape]): color: uint8 radius: uint16 - class AnyShape(OneOf[Shape]): - @classmethod - def select_from_base(cls, value: Shape, circle_allowed = False) -> Type[Shape]: - if value.radius is not None: - assert circle_allowed - return Circle - if value.side is not None: - return Square - assert False - # Defines a container with immutable scheme that contains two `StableContainer` class ShapePair(Container): shape_1: Shape @@ -533,18 +523,6 @@ class ShapePairRepr(Container): shape_1: ShapeRepr shape_2: ShapeRepr - class AnyShapePair(OneOf[ShapePair]): - @classmethod - def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[ShapePair]: - typ_1 = AnyShape.select_from_base(value.shape_1, circle_allowed) - typ_2 = AnyShape.select_from_base(value.shape_2, circle_allowed) - assert typ_1 == typ_2 - if typ_1 is Circle: - return CirclePair - if typ_1 is Square: - return SquarePair - assert False - # Square tests square_bytes_stable = bytes.fromhex("03420001") square_bytes_profile = bytes.fromhex("420001") @@ -564,9 +542,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert all(square.encode_bytes() == square_bytes_profile for square in squares) assert ( Square(backing=Shape.decode_bytes(square_bytes_stable).get_backing()) == - Square.decode_bytes(square_bytes_profile) == - AnyShape.decode_bytes(square_bytes_stable) == - AnyShape.decode_bytes(square_bytes_stable, circle_allowed = True) + Square.decode_bytes(square_bytes_profile) ) assert all(shape.hash_tree_root() == square_root for shape in shapes) assert all(square.hash_tree_root() == square_root for square in squares) @@ -603,9 +579,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert all(square.encode_bytes() == square_bytes_profile for square in squares) assert ( Square(backing=Shape.decode_bytes(square_bytes_stable).get_backing()) == - Square.decode_bytes(square_bytes_profile) == - AnyShape.decode_bytes(square_bytes_stable) == - AnyShape.decode_bytes(square_bytes_stable, circle_allowed = True) + Square.decode_bytes(square_bytes_profile) ) assert all(shape.hash_tree_root() == square_root for shape in shapes) assert all(square.hash_tree_root() == square_root for square in squares) @@ -644,8 +618,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert all(circle.encode_bytes() == circle_bytes_profile for circle in circles) assert ( Circle(backing=Shape.decode_bytes(circle_bytes_stable).get_backing()) == - Circle.decode_bytes(circle_bytes_profile) == - AnyShape.decode_bytes(circle_bytes_stable, circle_allowed = True) + Circle.decode_bytes(circle_bytes_profile) ) assert all(shape.hash_tree_root() == circle_root for shape in shapes) assert all(circle.hash_tree_root() == circle_root for circle in circles) @@ -666,11 +639,6 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert False except: pass - try: - circle = AnyShape.decode_bytes(circle_bytes_stable, circle_allowed = False) - assert False - except: - pass # SquarePair tests square_pair_bytes_stable = bytes.fromhex("080000000c0000000342000103690001") @@ -703,9 +671,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert all(pair.encode_bytes() == square_pair_bytes_profile for pair in square_pairs) assert ( SquarePair(backing=ShapePair.decode_bytes(square_pair_bytes_stable).get_backing()) == - SquarePair.decode_bytes(square_pair_bytes_profile) == - AnyShapePair.decode_bytes(square_pair_bytes_stable) == - AnyShapePair.decode_bytes(square_pair_bytes_stable, circle_allowed = True) + SquarePair.decode_bytes(square_pair_bytes_profile) ) assert all(pair.hash_tree_root() == square_pair_root for pair in shape_pairs) assert all(pair.hash_tree_root() == square_pair_root for pair in square_pairs) @@ -741,8 +707,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert all(pair.encode_bytes() == circle_pair_bytes_profile for pair in circle_pairs) assert ( CirclePair(backing=ShapePair.decode_bytes(circle_pair_bytes_stable).get_backing()) == - CirclePair.decode_bytes(circle_pair_bytes_profile) == - AnyShapePair.decode_bytes(circle_pair_bytes_stable, circle_allowed = True) + CirclePair.decode_bytes(circle_pair_bytes_profile) ) assert all(pair.hash_tree_root() == circle_pair_root for pair in shape_pairs) assert all(pair.hash_tree_root() == circle_pair_root for pair in circle_pairs) @@ -762,11 +727,6 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert False except: pass - try: - shape = AnyShape.decode_bytes(shape_bytes) - assert False - except: - pass shape = Shape(side=0x42, color=1, radius=0x42) shape_bytes = bytes.fromhex("074200014200") assert shape.encode_bytes() == shape_bytes @@ -781,16 +741,6 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert False except: pass - try: - shape = AnyShape.decode_bytes(shape_bytes) - assert False - except: - pass - try: - shape = AnyShape.decode_bytes("00") - assert False - except: - pass try: shape = Shape.decode_bytes("00") assert False From b374d35504406035d8a0dd0295c2c6ec035378bb Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Tue, 4 Jun 2024 12:45:53 +0200 Subject: [PATCH 21/23] Disallow field reordering in `Profile` --- remerkleable/stable_container.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 8172ece..8467c8d 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -349,6 +349,7 @@ def __init_subclass__(cls, **kwargs): raise TypeError(f'Cannot override `B` inside `{cls.__name__}`') cls._field_indices = {} cls._o = 0 + last_findex = -1 for (fkey, t) in cls.__annotations__.items(): if fkey not in cls.B._field_indices: raise TypeError( @@ -360,6 +361,12 @@ def __init_subclass__(cls, **kwargs): else: findex = cls.B._field_indices[fkey] ftyp = cls.B.fields()[fkey] + if findex <= last_findex: + raise TypeError( + f'`{cls.__name__}` fields must have the same order as in the base type ' + f'but `{fkey}` is defined earlier than in `{cls.B.__name__}`' + ) + last_findex = findex fopt = ( get_origin(t) == PyUnion and len(get_args(t)) == 2 From b833f77a0326f5f6b105d9bb225e67c37f928b0c Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Tue, 4 Jun 2024 20:25:48 +0200 Subject: [PATCH 22/23] Allow more complex composition when using `Profile` --- remerkleable/stable_container.py | 191 +++++++++++++++++++++---------- remerkleable/test_impl.py | 11 +- 2 files changed, 134 insertions(+), 68 deletions(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 8467c8d..44d9ba4 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -6,8 +6,10 @@ TypeVar, Type, Union as PyUnion, \ get_args, get_origin from textwrap import indent -from remerkleable.bitfields import Bitvector -from remerkleable.complex import ComplexView, Container, FieldOffset, \ +from remerkleable.basic import boolean, uint8, uint16, uint32, uint64, uint128, uint256 +from remerkleable.bitfields import Bitlist, Bitvector +from remerkleable.byte_arrays import ByteList, ByteVector +from remerkleable.complex import ComplexView, Container, FieldOffset, List, Vector, \ decode_offset, encode_offset from remerkleable.core import View, ViewHook, ViewMeta, OFFSET_BYTE_LENGTH from remerkleable.tree import Gindex, NavigationError, Node, PairNode, \ @@ -102,9 +104,9 @@ def __init_subclass__(cls, **kwargs): raise TypeError(f'Missing capacity: `{cls.__name__}(StableContainer)`') n = kwargs.pop('n') if not isinstance(n, int): - raise TypeError(f'Invalid capacity: `{cls.__name__}(StableContainer[{n}])`') + raise TypeError(f'Invalid capacity: `StableContainer[{n}]`') if n <= 0: - raise TypeError(f'Unsupported capacity: `{cls.__name__}(StableContainer[{n}])`') + raise TypeError(f'Unsupported capacity: `StableContainer[{n}]`') cls.N = n def __class_getitem__(cls, n: int) -> Type['StableContainer']: @@ -308,7 +310,7 @@ class Profile(ComplexView): __slots__ = '_field_indices', '_o', 'B' _field_indices: Dict[str, Tuple[int, Type[View], bool]] _o: int - B: PyUnion[Type[StableContainer], Type[Container]] + B: Type[StableContainer] def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None, **kwargs): if backing is not None: @@ -334,11 +336,124 @@ def __init_subclass__(cls, **kwargs): if 'b' not in kwargs: raise TypeError(f'Missing base type: `{cls.__name__}(Profile)`') b = kwargs.pop('b') - if not issubclass(b, StableContainer) and not issubclass(b, Container): - raise TypeError(f'Invalid base type: `{cls.__name__}(Profile[{b.__name__}])`') + if not issubclass(b, StableContainer): + raise TypeError(f'Invalid base type: `Profile[{b.__name__}]`') cls.B = b def __class_getitem__(cls, b) -> Type['Profile']: + def has_compatible_merkleization(ftyp, ftyp_base) -> bool: + if ftyp == ftyp_base: + return True + if issubclass(ftyp, boolean): + return issubclass(ftyp_base, boolean) + if issubclass(ftyp, uint8): + return issubclass(ftyp_base, uint8) + if issubclass(ftyp, uint16): + return issubclass(ftyp_base, uint16) + if issubclass(ftyp, uint32): + return issubclass(ftyp_base, uint32) + if issubclass(ftyp, uint64): + return issubclass(ftyp_base, uint64) + if issubclass(ftyp, uint128): + return issubclass(ftyp_base, uint128) + if issubclass(ftyp, uint256): + return issubclass(ftyp_base, uint256) + if issubclass(ftyp, Bitlist): + return ( + issubclass(ftyp_base, Bitlist) + and ftyp.limit() == ftyp_base.limit() + ) + if issubclass(ftyp, Bitvector): + return ( + issubclass(ftyp_base, Bitvector) + and ftyp.vector_length() == ftyp_base.vector_length() + ) + if issubclass(ftyp, ByteList): + if issubclass(ftyp_base, ByteList): + return ftyp.limit() == ftyp_base.limit() + return ( + issubclass(ftyp_base, List) + and ftyp.limit() == ftyp_base.limit() + and issubclass(ftyp_base.element_cls(), uint8) + ) + if issubclass(ftyp, ByteVector): + if issubclass(ftyp_base, ByteVector): + return ftyp.vector_length() == ftyp_base.vector_length() + return ( + issubclass(ftyp_base, Vector) + and ftyp.vector_length() == ftyp_base.vector_length() + and issubclass(ftyp_base.element_cls(), uint8) + ) + if issubclass(ftyp, List): + if issubclass(ftyp_base, ByteList): + return ( + ftyp.limit() == ftyp_base.limit() + and issubclass(ftyp.element_cls(), uint8) + ) + return ( + issubclass(ftyp_base, List) + and ftyp.limit() == ftyp_base.limit() + and has_compatible_merkleization(ftyp.element_cls(), ftyp_base.element_cls()) + ) + if issubclass(ftyp, Vector): + if issubclass(ftyp_base, ByteVector): + return ( + ftyp.vector_length() == ftyp_base.vector_length() + and issubclass(ftyp.element_cls(), uint8) + ) + return ( + issubclass(ftyp_base, Vector) + and ftyp.vector_length() == ftyp_base.vector_length() + and has_compatible_merkleization(ftyp.element_cls(), ftyp_base.element_cls()) + ) + if issubclass(ftyp, Container): + if not issubclass(ftyp_base, Container): + return False + fields = ftyp.fields() + fields_base = ftyp_base.fields() + if len(fields) != len(fields_base): + return False + for fkey, t in fields.items(): + if fkey not in fields_base: + return False + if not has_compatible_merkleization(t, fields_base[fkey]): + return False + return True + if issubclass(ftyp, StableContainer): + if not issubclass(ftyp_base, StableContainer): + return False + if ftyp.N != ftyp_base.N: + return False + fields = ftyp.fields() + fields_base = ftyp_base.fields() + if len(fields) != len(fields_base): + return False + for fkey, t in fields.items(): + if fkey not in fields_base: + return False + if not has_compatible_merkleization(t, fields_base[fkey]): + return False + return True + if issubclass(ftyp, Profile): + if issubclass(ftyp_base, StableContainer): + return has_compatible_merkleization(ftyp.B, ftyp_base) + if not issubclass(ftyp_base, Profile): + return False + if not has_compatible_merkleization(ftyp.B, ftyp_base.B): + return False + fields = ftyp.fields() + fields_base = ftyp_base.fields() + if len(fields) != len(fields_base): + return False + for fkey, (t, _) in fields.items(): + if fkey not in fields_base: + return False + (t_base, _) = fields_base[fkey] + if not has_compatible_merkleization(t, t_base): + return False + return True + return False + class ProfileMeta(ViewMeta): def __new__(cls, name, bases, dct): return super().__new__(cls, name, bases, dct, b=b) @@ -356,11 +471,7 @@ def __init_subclass__(cls, **kwargs): f'`{cls.__name__}` fields must exist in the base type ' f'but `{fkey}` is not defined in `{cls.B.__name__}`' ) - if issubclass(cls.B, StableContainer): - (findex, ftyp) = cls.B._field_indices[fkey] - else: - findex = cls.B._field_indices[fkey] - ftyp = cls.B.fields()[fkey] + (findex, ftyp) = cls.B._field_indices[fkey] if findex <= last_findex: raise TypeError( f'`{cls.__name__}` fields must have the same order as in the base type ' @@ -373,17 +484,8 @@ def __init_subclass__(cls, **kwargs): and type(None) in get_args(t) ) if fopt: - if not issubclass(cls.B, StableContainer): - raise TypeError( - f'`{cls.__name__}.{fkey}` cannot be `Optional[T]` ' - f'as base type `{cls.B.__name__}` is not a `StableContainer`' - ) t = get_args(t)[0] if get_args(t)[0] is not type(None) else get_args(t)[1] - if t == ftyp: - pass - elif issubclass(t, Profile) and t.B == ftyp: - pass - else: + if not has_compatible_merkleization(t, ftyp): raise TypeError( f'`{cls.__name__}.{fkey}` has type `{t.__name__}`, incompatible ' f'with base field `{cls.B.__name__}.{fkey}` of type `{ftyp.__name__}`' @@ -391,16 +493,6 @@ def __init_subclass__(cls, **kwargs): cls._field_indices[fkey] = (findex, t, fopt) if fopt: cls._o += 1 - if ( - not issubclass(cls.B, StableContainer) - and len(cls._field_indices) != len(cls.B._field_indices) - ): - for fkey, (findex, ftyp) in cls.B._field_indices.items(): - if fkey not in cls._field_indices: - raise TypeError( - f'`{cls.__name__}.{fkey}` of type `{ftyp.__name__}` is required ' - f'as base type `{cls.B.__name__}` is not a `StbleContainer`' - ) ProfileView.__name__ = ProfileView.type_repr() return ProfileView @@ -458,14 +550,10 @@ def item_elem_cls(cls, i: int) -> Type[View]: return cls.B.item_elem_cls(i) def active_fields(self) -> Bitvector: - if not issubclass(self.__class__.B, StableContainer): - raise Exception(f'`active_fields` requires `Profile` with `StableContainer` base') active_fields_node = super().get_backing().get_right() return Bitvector[self.__class__.B.N].view_from_backing(active_fields_node) def optional_fields(self) -> Bitvector: - if not issubclass(self.__class__.B, StableContainer): - raise Exception(f'`optional_fields` requires `Profile` with `StableContainer` base') if self.__class__._o == 0: raise Exception(f'`{self.__class__.__name__}` does not have any `Optional[T]` fields') active_fields = self.active_fields() @@ -491,12 +579,7 @@ def __getattr__(self, item): except KeyError: raise AttributeError(f'Unknown field `{item}`') - if issubclass(self.__class__.B, StableContainer): - value = stable_get(self, findex, ftyp, self.__class__.B.N) - else: - value = super().get(findex) - if not isinstance(value, ftyp): - value = ftyp(backing=value.get_backing()) + value = stable_get(self, findex, ftyp, self.__class__.B.N) assert value is not None or fopt return value @@ -511,11 +594,7 @@ def __setattr__(self, key, value): if value is None and not fopt: raise ValueError(f'Field `{key}` is required and cannot be set to `None`') - - if issubclass(self.__class__.B, StableContainer): - stable_set(self, findex, ftyp, self.__class__.B.N, value) - else: - super().set(findex, value) + stable_set(self, findex, ftyp, self.__class__.B.N, value) def __repr__(self): return f'{self.__class__.type_repr()}:\n' + '\n'.join( @@ -598,15 +677,11 @@ def serialize(self, stream: BinaryIO) -> int: if has_dyn_fields: temp_dyn_stream = io.BytesIO() - if issubclass(self.__class__.B, StableContainer): - data = super().get_backing().get_left() - active_fields = self.active_fields() - n = self.__class__.B.N - else: - data = super().get_backing() - n = len(self.__class__.B.fields()) + data = super().get_backing().get_left() + active_fields = self.active_fields() + n = self.__class__.B.N for (findex, ftyp, _) in self.__class__._field_indices.values(): - if issubclass(self.__class__.B, StableContainer) and not active_fields.get(findex): + if not active_fields.get(findex): continue fnode = data.getter(2**get_depth(n) + findex) v = ftyp.view_from_backing(fnode) @@ -633,8 +708,4 @@ def key_to_static_gindex(cls, key: Any) -> Gindex: if key == '__active_fields__': return RIGHT_GINDEX (findex, _, _) = cls._field_indices[key] - if issubclass(cls.B, StableContainer): - return 2**get_depth(cls.B.N) * 2 + findex - else: - n = len(cls.B.fields()) - return 2**get_depth(n) + findex + return 2**get_depth(cls.B.N) * 2 + findex diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index 61fe485..966bfb2 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -479,34 +479,28 @@ class B(Container): def test_stable_container(): - # Defines the common merkleization format and a portable serialization format class Shape(StableContainer[4]): side: Optional[uint16] color: Optional[uint8] radius: Optional[uint16] - # Inherits merkleization format from `Shape`, but is serialized more compactly class Square(Profile[Shape]): side: uint16 color: uint8 - # Inherits merkleization format from `Shape`, but is serialized more compactly class Circle(Profile[Shape]): color: uint8 radius: uint16 - # Defines a container with immutable scheme that contains two `StableContainer` class ShapePair(Container): shape_1: Shape shape_2: Shape - # Inherits merkleization format from `ShapePair`, and serializes more compactly - class SquarePair(Profile[ShapePair]): + class SquarePair(Container): shape_1: Square shape_2: Square - # Inherits merkleization format from `ShapePair`, and serializes more compactly - class CirclePair(Profile[ShapePair]): + class CirclePair(Container): shape_1: Circle shape_2: Circle @@ -515,6 +509,7 @@ class ShapePayload(Container): side: uint16 color: uint8 radius: uint16 + class ShapeRepr(Container): value: ShapePayload active_fields: Bitvector[4] From f0ffb4938794fed6623ca82c5a2fcb20ed2d57fd Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Tue, 4 Jun 2024 20:44:03 +0200 Subject: [PATCH 23/23] Compare field order in `Profile` compatibility check --- remerkleable/stable_container.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 44d9ba4..667dc27 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -413,10 +413,10 @@ def has_compatible_merkleization(ftyp, ftyp_base) -> bool: fields_base = ftyp_base.fields() if len(fields) != len(fields_base): return False - for fkey, t in fields.items(): - if fkey not in fields_base: + for (fkey, t), (fkey_b, t_b) in zip(fields.items(), fields_base.items()): + if fkey != fkey_b: return False - if not has_compatible_merkleization(t, fields_base[fkey]): + if not has_compatible_merkleization(t, t_b): return False return True if issubclass(ftyp, StableContainer): @@ -428,10 +428,10 @@ def has_compatible_merkleization(ftyp, ftyp_base) -> bool: fields_base = ftyp_base.fields() if len(fields) != len(fields_base): return False - for fkey, t in fields.items(): - if fkey not in fields_base: + for (fkey, t), (fkey_b, t_b) in zip(fields.items(), fields_base.items()): + if fkey != fkey_b: return False - if not has_compatible_merkleization(t, fields_base[fkey]): + if not has_compatible_merkleization(t, t_b): return False return True if issubclass(ftyp, Profile): @@ -445,11 +445,10 @@ def has_compatible_merkleization(ftyp, ftyp_base) -> bool: fields_base = ftyp_base.fields() if len(fields) != len(fields_base): return False - for fkey, (t, _) in fields.items(): - if fkey not in fields_base: + for (fkey, (t, _)), (fkey_b, (t_b, _)) in zip(fields.items(), fields_base.items()): + if fkey != fkey_b: return False - (t_base, _) = fields_base[fkey] - if not has_compatible_merkleization(t, t_base): + if not has_compatible_merkleization(t, t_b): return False return True return False