diff --git a/clvm/SExp.py b/clvm/SExp.py index 398fe108..1db0f05b 100644 --- a/clvm/SExp.py +++ b/clvm/SExp.py @@ -167,9 +167,9 @@ def nullp(self): def as_int(self): return int_from_bytes(self.atom) - def as_bin(self): + def as_bin(self, *, allow_backrefs=False): f = io.BytesIO() - sexp_to_stream(self, f) + sexp_to_stream(self, f, allow_backrefs=allow_backrefs) return f.getvalue() @classmethod diff --git a/clvm/object_cache.py b/clvm/object_cache.py new file mode 100644 index 00000000..04ade03a --- /dev/null +++ b/clvm/object_cache.py @@ -0,0 +1,98 @@ +import hashlib + + +class ObjectCache: + """ + `ObjectCache` provides a way to calculate and cache values for each node + in a clvm object tree. It can be used to calculate the sha256 tree hash + for an object and save the hash for all the child objects for building + usage tables, for example. + + It also allows a function that's defined recursively on a clvm tree to + have a non-recursive implementation (as it keeps a stack of uncached + objects locally). + """ + + def __init__(self, f): + """ + `f`: Callable[ObjectCache, CLVMObject] -> Union[None, T] + + The function `f` is expected to calculate its T value recursively based + on the T values for the left and right child for a pair. For an atom, the + function f must calculate the T value directly. + + If a pair is passed and one of the children does not have its T value cached + in `ObjectCache` yet, return `None` and f will be called with each child in turn. + Don't recurse in f; that's part of the point of this function. + """ + self.f = f + self.lookup = dict() + + def get(self, obj): + obj_id = id(obj) + if obj_id not in self.lookup: + obj_list = [obj] + while obj_list: + node = obj_list.pop() + node_id = id(node) + if node_id not in self.lookup: + v = self.f(self, node) + if v is None: + if node.pair is None: + raise ValueError("f returned None for atom", node) + obj_list.append(node) + obj_list.append(node.pair[0]) + obj_list.append(node.pair[1]) + else: + self.lookup[node_id] = (v, node) + return self.lookup[obj_id][0] + + def contains(self, obj): + return id(obj) in self.lookup + + +def treehash(cache, obj): + """ + This function can be fed to `ObjectCache` to calculate the sha256 tree + hash for all objects in a tree. + """ + if obj.pair: + left, right = obj.pair + + # ensure both `left` and `right` have cached values + if cache.contains(left) and cache.contains(right): + left_hash = cache.get(left) + right_hash = cache.get(right) + return hashlib.sha256(b"\2" + left_hash + right_hash).digest() + return None + return hashlib.sha256(b"\1" + obj.atom).digest() + + +def serialized_length(cache, obj): + """ + This function can be fed to `ObjectCache` to calculate the serialized + length for all objects in a tree. + """ + if obj.pair: + left, right = obj.pair + + # ensure both `left` and `right` have cached values + if cache.contains(left) and cache.contains(right): + left_length = cache.get(left) + right_length = cache.get(right) + return 1 + left_length + right_length + return None + lb = len(obj.atom) + if lb == 0 or (lb == 1 and obj.atom[0] < 128): + return 1 + if lb < 0x40: + return 1 + lb + if lb < 0x2000: + return 2 + lb + if lb < 0x100000: + return 3 + lb + if lb < 0x8000000: + return 4 + lb + if lb < 0x400000000: + return 5 + lb + raise ValueError("atom of size %d too long" % lb) diff --git a/clvm/read_cache_lookup.py b/clvm/read_cache_lookup.py new file mode 100644 index 00000000..7920a591 --- /dev/null +++ b/clvm/read_cache_lookup.py @@ -0,0 +1,178 @@ +from collections import Counter +from typing import Optional, List, Set, Tuple + +import hashlib + + +LEFT = 0 +RIGHT = 1 + + +class ReadCacheLookup: + """ + When deserializing a clvm object, a stack of deserialized child objects + is created, which can be used with back-references. A `ReadCacheLookup` keeps + track of the state of this stack and all child objects under each root + node in the stack so that we can quickly determine if a relevant + back-reference is available. + + In other words, if we've already serialized an object with tree hash T, + and we encounter another object with that tree hash, we don't re-serialize + it, but rather include a back-reference to it. This data structure lets + us quickly determine which back-reference has the shortest path. + + Note that there is a counter. This is because the stack contains some + child objects that are transient, and no longer appear in the stack + at later times in the parsing. We don't want to waste time looking for + these objects that no longer exist, so we reference-count them. + + All hashes correspond to sha256 tree hashes. + """ + + def __init__(self): + """ + Create a new `ReadCacheLookup` object with just the null terminator + (ie. an empty list of objects). + """ + self.root_hash = hashlib.sha256(b"\1").digest() + self.read_stack = [] + self.count = Counter() + self.parent_paths_for_child = {} + + def push(self, obj_hash: bytes) -> None: + """ + This function is used to note that an object with the given hash has just + been pushed to the read stack, and update the lookups as appropriate. + """ + # we add two new entries: the new root of the tree, and this object (by id) + # new_root: (obj_hash, old_root) + new_root_hash = hashlib.sha256(b"\2" + obj_hash + self.root_hash).digest() + + self.read_stack.append((obj_hash, self.root_hash)) + + self.count.update([obj_hash, new_root_hash]) + + new_parent_to_old_root = (new_root_hash, LEFT) + self.parent_paths_for_child.setdefault(obj_hash, list()).append( + new_parent_to_old_root + ) + + new_parent_to_id = (new_root_hash, RIGHT) + self.parent_paths_for_child.setdefault(self.root_hash, list()).append( + new_parent_to_id + ) + self.root_hash = new_root_hash + + def pop(self) -> Tuple[bytes, bytes]: + """ + This function is used to note that the top object has just been popped + from the read stack. Return the 2-tuple of the child hashes. + """ + item = self.read_stack.pop() + self.count[item[0]] -= 1 + self.count[self.root_hash] -= 1 + self.root_hash = item[1] + return item + + def pop2_and_cons(self) -> None: + """ + This function is used to note that a "pop-and-cons" operation has just + happened. We remove two objects, cons them together, and push the cons, + updating the internal look-ups as necessary. + """ + # we remove two items: the right side of each left/right pair + right = self.pop() + left = self.pop() + + self.count.update([left[0], right[0]]) + + new_root_hash = hashlib.sha256(b"\2" + left[0] + right[0]).digest() + + self.parent_paths_for_child.setdefault(left[0], list()).append( + (new_root_hash, LEFT) + ) + self.parent_paths_for_child.setdefault(right[0], list()).append( + (new_root_hash, RIGHT) + ) + self.push(new_root_hash) + + def find_paths(self, obj_hash: bytes, serialized_length: int) -> Set[bytes]: + """ + This function looks for a path from the root to a child node with a given hash + by using the read cache. + """ + valid_paths = set() + if serialized_length < 3: + return valid_paths + + seen_ids = set() + + max_bytes_for_path_encoding = serialized_length - 2 + # 1 byte for 0xfe, 1 min byte for savings + + max_path_length = max_bytes_for_path_encoding * 8 - 1 + seen_ids.add(obj_hash) + + partial_paths = [(obj_hash, [])] + + while partial_paths: + new_seen_ids = set(seen_ids) + new_partial_paths = [] + for (node, path) in partial_paths: + if node == self.root_hash: + valid_paths.add(reversed_path_to_bytes(path)) + continue + + parent_paths = self.parent_paths_for_child.get(node) + + if parent_paths: + for (parent, direction) in parent_paths: + if self.count[parent] > 0 and parent not in seen_ids: + new_path = list(path) + new_path.append(direction) + if len(new_path) > max_path_length: + return set() + new_partial_paths.append((parent, new_path)) + new_seen_ids.add(parent) + partial_paths = new_partial_paths + if valid_paths: + return valid_paths + seen_ids = frozenset(new_seen_ids) + return valid_paths + + def find_path(self, obj_hash: bytes, serialized_length: int) -> Optional[bytes]: + r = self.find_paths(obj_hash, serialized_length) + return min(r) if len(r) > 0 else None + + +def reversed_path_to_bytes(path: List[int]) -> bytes: + """ + Convert a list of 0/1 (for left/right) values to a path expected by clvm. + + Reverse the list; convert to a binary number; prepend a 1; break into bytes. + + [] => bytes([0b1]) + [0] => bytes([0b10]) + [1] => bytes([0b11]) + [0, 0] => bytes([0b100]) + [0, 1] => bytes([0b101]) + [1, 0] => bytes([0b110]) + [1, 1] => bytes([0b111]) + [0, 0, 1] => bytes([0b1001]) + [1, 1, 1, 1, 0, 0, 0, 0, 1] => bytes([0b11, 0b11100001]) + """ + + byte_count = (len(path) + 1 + 7) >> 3 + v = bytearray(byte_count) + index = byte_count - 1 + mask = 1 + for p in reversed(path): + if p: + v[index] |= mask + if mask == 0x80: + index -= 1 + mask = 1 + else: + mask <<= 1 + v[index] |= mask + return bytes(v) diff --git a/clvm/serialize.py b/clvm/serialize.py index d685794a..7e5b892d 100644 --- a/clvm/serialize.py +++ b/clvm/serialize.py @@ -1,6 +1,7 @@ # decoding: # read a byte # if it's 0x80, it's nil (which might be same as 0) +# if it's 0xfe, it's a back-reference. Read an atom, and treat it as a path in the cache tree. # if it's 0xff, it's a cons box. Read two items, build cons # otherwise, number of leading set bits is length in bytes to read size # For example, if the bit fields of the first byte read are: @@ -12,25 +13,80 @@ # If the first byte read is one of the following: # 1000 0000 -> 0 bytes : nil # 0000 0000 -> 1 byte : zero (b'\x00') + +from typing import Iterator + import io -from .CLVMObject import CLVMObject +from .read_cache_lookup import ReadCacheLookup +from .object_cache import ObjectCache, treehash, serialized_length MAX_SINGLE_BYTE = 0x7F +BACK_REFERENCE = 0xFE CONS_BOX_MARKER = 0xFF -def sexp_to_byte_iterator(sexp): +def sexp_to_byte_iterator(sexp, *, allow_backrefs=False) -> Iterator[bytes]: + if allow_backrefs: + yield from sexp_to_byte_iterator_with_backrefs(sexp) + return + todo_stack = [sexp] while todo_stack: sexp = todo_stack.pop() - pair = sexp.as_pair() + pair = sexp.pair if pair: yield bytes([CONS_BOX_MARKER]) todo_stack.append(pair[1]) todo_stack.append(pair[0]) else: - yield from atom_to_byte_iterator(sexp.as_atom()) + yield from atom_to_byte_iterator(sexp.atom) + + +def sexp_to_byte_iterator_with_backrefs(sexp) -> Iterator[bytes]: + + # in `read_op_stack`: + # "P" = "push" + # "C" = "pop two objects, create and push a new cons with them" + + read_op_stack = ["P"] + + write_stack = [sexp] + + read_cache_lookup = ReadCacheLookup() + + thc = ObjectCache(treehash) + slc = ObjectCache(serialized_length) + + while write_stack: + node_to_write = write_stack.pop() + op = read_op_stack.pop() + assert op == "P" + + node_serialized_length = slc.get(node_to_write) + + node_tree_hash = thc.get(node_to_write) + path = read_cache_lookup.find_path(node_tree_hash, node_serialized_length) + if path: + yield bytes([BACK_REFERENCE]) + yield from atom_to_byte_iterator(path) + read_cache_lookup.push(node_tree_hash) + elif node_to_write.pair: + left, right = node_to_write.pair + yield bytes([CONS_BOX_MARKER]) + write_stack.append(right) + write_stack.append(left) + read_op_stack.append("C") + read_op_stack.append("P") + read_op_stack.append("P") + else: + atom = node_to_write.atom + yield from atom_to_byte_iterator(atom) + read_cache_lookup.push(node_tree_hash) + + while read_op_stack[-1:] == ["C"]: + read_op_stack.pop() + read_cache_lookup.pop2_and_cons() def atom_to_byte_iterator(as_atom): @@ -74,11 +130,32 @@ def atom_to_byte_iterator(as_atom): yield as_atom -def sexp_to_stream(sexp, f): - for b in sexp_to_byte_iterator(sexp): +def sexp_to_stream(sexp, f, *, allow_backrefs=False): + for b in sexp_to_byte_iterator(sexp, allow_backrefs=allow_backrefs): f.write(b) +def traverse_path(obj, path: bytes, to_sexp): + path_as_int = int.from_bytes(path, "big") + if path_as_int == 0: + return to_sexp(b"") + + while path_as_int > 1: + if obj.pair is None: + raise ValueError("path into atom", obj) + obj = obj.pair[path_as_int & 1] + path_as_int >>= 1 + + return to_sexp(obj) + + +def _op_cons(op_stack, val_stack, f, to_sexp): + right, val_stack = val_stack.pair + left, val_stack = val_stack.pair + new_cons = to_sexp((left, right)) + return to_sexp((new_cons, val_stack)) + + def _op_read_sexp(op_stack, val_stack, f, to_sexp): blob = f.read(1) if len(blob) == 0: @@ -88,24 +165,38 @@ def _op_read_sexp(op_stack, val_stack, f, to_sexp): op_stack.append(_op_cons) op_stack.append(_op_read_sexp) op_stack.append(_op_read_sexp) - return - val_stack.append(_atom_from_stream(f, b, to_sexp)) + return val_stack + return to_sexp((_atom_from_stream(f, b, to_sexp), val_stack)) -def _op_cons(op_stack, val_stack, f, to_sexp): - right = val_stack.pop() - left = val_stack.pop() - val_stack.append(to_sexp((left, right))) +def _op_read_sexp_allow_backrefs(op_stack, val_stack, f, to_sexp): + blob = f.read(1) + if len(blob) == 0: + raise ValueError("bad encoding") + b = blob[0] + if b == CONS_BOX_MARKER: + op_stack.append(_op_cons) + op_stack.append(_op_read_sexp_allow_backrefs) + op_stack.append(_op_read_sexp_allow_backrefs) + return val_stack + if b == BACK_REFERENCE: + blob = f.read(1) + if len(blob) == 0: + raise ValueError("bad encoding") + path = _atom_from_stream(f, blob[0], lambda x: x) + backref = traverse_path(val_stack, path, to_sexp) + return to_sexp((backref, val_stack)) + return to_sexp((_atom_from_stream(f, b, to_sexp), val_stack)) -def sexp_from_stream(f, to_sexp): - op_stack = [_op_read_sexp] - val_stack = [] +def sexp_from_stream(f, to_sexp, *, allow_backrefs=False): + op_stack = [_op_read_sexp_allow_backrefs if allow_backrefs else _op_read_sexp] + val_stack = to_sexp(b"") while op_stack: func = op_stack.pop() - func(op_stack, val_stack, f, CLVMObject) - return to_sexp(val_stack.pop()) + val_stack = func(op_stack, val_stack, f, to_sexp) + return val_stack.pair[0] def _op_consume_sexp(f): diff --git a/setup.py b/setup.py index a1a86be7..187ecc6c 100755 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup( name="clvm", - packages=["clvm",], + packages=["clvm"], author="Chia Network, Inc.", author_email="hello@chia.net", url="https://github.com/Chia-Network/clvm", diff --git a/tests/generator.bin.gz b/tests/generator.bin.gz new file mode 100644 index 00000000..b120ff3e Binary files /dev/null and b/tests/generator.bin.gz differ diff --git a/tests/object_cache_test.py b/tests/object_cache_test.py new file mode 100644 index 00000000..f4c30f88 --- /dev/null +++ b/tests/object_cache_test.py @@ -0,0 +1,41 @@ +import unittest + +from clvm.object_cache import ObjectCache, treehash, serialized_length + +from clvm_tools.binutils import assemble + + +class ObjectCacheTest(unittest.TestCase): + def check(self, obj_text, expected_hash, expected_length): + obj = assemble(obj_text) + th = ObjectCache(treehash) + self.assertEqual(th.get(obj).hex(), expected_hash) + sl = ObjectCache(serialized_length) + self.assertEqual(sl.get(obj), expected_length) + + def test_various(self): + self.check( + "0x00", + "47dc540c94ceb704a23875c11273e16bb0b8a87aed84de911f2133568115f254", + 1, + ) + + self.check( + "0", "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", 1 + ) + + self.check( + "foo", "0080b50a51ecd0ccfaaa4d49dba866fe58724f18445d30202bafb03e21eef6cb", 4 + ) + + self.check( + "(foo . bar)", + "c518e45ae6a7b4146017b7a1d81639051b132f1f5572ce3088a3898a9ed1280b", + 9, + ) + + self.check( + "(this is a longer test of a deeper tree)", + "0a072d7d860d77d8e290ced0fdb29a271198ca3db54d701c45d831e3aae6422c", + 47, + ) diff --git a/tests/read_cache_lookup_test.py b/tests/read_cache_lookup_test.py new file mode 100644 index 00000000..81c63845 --- /dev/null +++ b/tests/read_cache_lookup_test.py @@ -0,0 +1,105 @@ +import unittest + +from clvm import to_sexp_f +from clvm.read_cache_lookup import ReadCacheLookup +from clvm.object_cache import ObjectCache, treehash + + +class ReadCacheLookupTest(unittest.TestCase): + def test_various(self): + rcl = ReadCacheLookup() + treehasher = ObjectCache(treehash) + + # rcl = () + nil = to_sexp_f(b"") + nil_hash = treehasher.get(nil) + self.assertEqual(rcl.root_hash, nil_hash) + + foo = to_sexp_f(b"foo") + foo_hash = treehasher.get(foo) + rcl.push(foo_hash) + + # rcl = (foo . 0) + + current_stack = to_sexp_f([foo]) + current_stack_hash = treehasher.get(current_stack) + + self.assertEqual(rcl.root_hash, current_stack_hash) + self.assertEqual(rcl.find_path(foo_hash, serialized_length=20), bytes([2])) + self.assertEqual(rcl.find_path(nil_hash, serialized_length=20), bytes([3])) + self.assertEqual( + rcl.find_path(current_stack_hash, serialized_length=20), bytes([1]) + ) + + bar = to_sexp_f(b"bar") + bar_hash = treehasher.get(bar) + rcl.push(bar_hash) + + # rcl = (bar foo) + + current_stack = to_sexp_f([bar, foo]) + current_stack_hash = treehasher.get(current_stack) + foo_list_hash = treehasher.get(to_sexp_f([b"foo"])) + self.assertEqual(rcl.root_hash, current_stack_hash) + self.assertEqual(rcl.find_path(bar_hash, serialized_length=20), bytes([2])) + self.assertEqual(rcl.find_path(foo_list_hash, serialized_length=20), bytes([3])) + self.assertEqual(rcl.find_path(foo_hash, serialized_length=20), bytes([5])) + self.assertEqual(rcl.find_path(nil_hash, serialized_length=20), bytes([7])) + self.assertEqual( + rcl.find_path(current_stack_hash, serialized_length=20), bytes([1]) + ) + self.assertEqual(rcl.count[foo_list_hash], 1) + + rcl.pop2_and_cons() + # rcl = ((foo . bar) . 0) + + current_stack = to_sexp_f([(foo, bar)]) + current_stack_hash = treehasher.get(current_stack) + self.assertEqual(rcl.root_hash, current_stack_hash) + + # we no longer have `(foo . 0)` in the read stack + # check that its count is zero + self.assertEqual(rcl.count[foo_list_hash], 0) + + self.assertEqual(rcl.find_path(bar_hash, serialized_length=20), bytes([6])) + self.assertEqual(rcl.find_path(foo_list_hash, serialized_length=20), None) + self.assertEqual(rcl.find_path(foo_hash, serialized_length=20), bytes([4])) + self.assertEqual(rcl.find_path(nil_hash, serialized_length=20), bytes([3])) + self.assertEqual( + rcl.find_path(current_stack_hash, serialized_length=20), bytes([1]) + ) + + rcl.push(foo_hash) + rcl.push(foo_hash) + rcl.pop2_and_cons() + + # rcl = ((foo . foo) (foo . bar)) + + current_stack = to_sexp_f([(foo, foo), (foo, bar)]) + current_stack_hash = treehasher.get(current_stack) + self.assertEqual(rcl.root_hash, current_stack_hash) + self.assertEqual(rcl.find_path(bar_hash, serialized_length=20), bytes([13])) + self.assertEqual(rcl.find_path(foo_list_hash, serialized_length=20), None) + self.assertEqual(rcl.find_path(foo_hash, serialized_length=20), bytes([4])) + self.assertEqual(rcl.find_path(nil_hash, serialized_length=20), bytes([7])) + + # find BOTH minimal paths to `foo` + self.assertEqual( + rcl.find_paths(foo_hash, serialized_length=20), + set([bytes([4]), bytes([6])]), + ) + + rcl = ReadCacheLookup() + rcl.push(foo_hash) + rcl.push(foo_hash) + rcl.pop2_and_cons() + rcl.push(foo_hash) + rcl.push(foo_hash) + rcl.pop2_and_cons() + rcl.pop2_and_cons() + # rcl = ((foo . foo) . (foo . foo)) + # find ALL minimal paths to `foo` + self.assertEqual( + rcl.find_paths(foo_hash, serialized_length=20), + set([bytes([8]), bytes([10]), bytes([12]), bytes([14])]), + ) diff --git a/tests/serialize_test.py b/tests/serialize_test.py index 786f1c95..3ebebd23 100644 --- a/tests/serialize_test.py +++ b/tests/serialize_test.py @@ -1,8 +1,14 @@ +import gzip import io import unittest from clvm import to_sexp_f -from clvm.serialize import (sexp_from_stream, sexp_buffer_from_stream, atom_to_byte_iterator) +from clvm.serialize import ( + _atom_from_stream, + sexp_from_stream, + sexp_buffer_from_stream, + atom_to_byte_iterator, +) TEXT = b"the quick brown fox jumps over the lazy dogs" @@ -13,12 +19,12 @@ def __init__(self, b): self.buf = b def read(self, n): - ret = b'' + ret = b"" while n > 0 and len(self.buf) > 0: ret += self.buf[0:1] self.buf = self.buf[1:] n -= 1 - ret += b' ' * n + ret += b" " * n return ret @@ -27,6 +33,24 @@ def __len__(self): return 0x400000001 +def has_backrefs(blob: bytes) -> bool: + """ + Return `True` iff blob has a backref in it. + """ + f = io.BytesIO(blob) + obj_count = 1 + while obj_count > 0: + b = f.read(1)[0] + if b == 0xFE: + return True + if b == 0xFF: + obj_count += 1 + else: + _atom_from_stream(f, b, lambda x: x) + obj_count -= 1 + return False + + class SerializeTest(unittest.TestCase): def check_serde(self, s): v = to_sexp_f(s) @@ -44,6 +68,22 @@ def check_serde(self, s): buf = sexp_buffer_from_stream(io.BytesIO(b)) self.assertEqual(buf, b) + # now turn on backrefs and make sure everything still works + + b2 = v.as_bin(allow_backrefs=True) + self.assertTrue(len(b2) <= len(b)) + if has_backrefs(b2) or len(b2) < len(b): + # if we have any backrefs, ensure they actually save space + self.assertTrue(len(b2) < len(b)) + io_b2 = io.BytesIO(b2) + self.assertRaises(ValueError, lambda: sexp_from_stream(io_b2, to_sexp_f)) + io_b2 = io.BytesIO(b2) + v2 = sexp_from_stream(io_b2, to_sexp_f, allow_backrefs=True) + self.assertEqual(v2, s) + b3 = v2.as_bin() + self.assertEqual(b, b3) + return b2 + def test_zero(self): v = to_sexp_f(b"\x00") self.assertEqual(v.as_bin(), b"\x00") @@ -79,7 +119,7 @@ def test_long_blobs(self): def test_blob_limit(self): with self.assertRaises(ValueError): for b in atom_to_byte_iterator(LargeAtom()): - print('%02x' % b) + print("%02x" % b) def test_very_long_blobs(self): for size in [0x40, 0x2000, 0x100000, 0x8000000]: @@ -100,7 +140,7 @@ def test_very_deep_tree(self): self.check_serde(s) def test_deserialize_empty(self): - bytes_in = b'' + bytes_in = b"" with self.assertRaises(ValueError): sexp_from_stream(io.BytesIO(bytes_in), to_sexp_f) @@ -110,7 +150,7 @@ def test_deserialize_empty(self): def test_deserialize_truncated_size(self): # fe means the total number of bytes in the length-prefix is 7 # one for each bit set. 5 bytes is too few - bytes_in = b'\xfe ' + bytes_in = b"\xfe " with self.assertRaises(ValueError): sexp_from_stream(io.BytesIO(bytes_in), to_sexp_f) @@ -120,7 +160,7 @@ def test_deserialize_truncated_size(self): def test_deserialize_truncated_blob(self): # this is a complete length prefix. The blob is supposed to be 63 bytes # the blob itself is truncated though, it's less than 63 bytes - bytes_in = b'\xbf ' + bytes_in = b"\xbf " with self.assertRaises(ValueError): sexp_from_stream(io.BytesIO(bytes_in), to_sexp_f) @@ -134,10 +174,57 @@ def test_deserialize_large_blob(self): # we don't support blobs this large, and we should fail immediately when # exceeding the max blob size, rather than trying to read this many # bytes from the stream - bytes_in = b'\xfe' + b'\xff' * 6 + bytes_in = b"\xfe" + b"\xff" * 6 with self.assertRaises(ValueError): sexp_from_stream(InfiniteStream(bytes_in), to_sexp_f) with self.assertRaises(ValueError): sexp_buffer_from_stream(InfiniteStream(bytes_in)) + + def test_deserialize_generator(self): + blob = gzip.GzipFile("tests/generator.bin.gz").read() + s = sexp_from_stream(io.BytesIO(blob), to_sexp_f) + b = self.check_serde(s) + assert len(b) == 19124 + + def test_deserialize_bomb(self): + def make_bomb(depth): + bomb = TEXT + for _ in range(depth): + bomb = to_sexp_f((bomb, bomb)) + return bomb + + bomb_10 = make_bomb(10) + b10_1 = bomb_10.as_bin(allow_backrefs=False) + b10_2 = bomb_10.as_bin(allow_backrefs=True) + self.assertEqual(len(b10_1), 47103) + self.assertEqual(len(b10_2), 75) + + bomb_20 = make_bomb(20) + b20_1 = bomb_20.as_bin(allow_backrefs=False) + b20_2 = bomb_20.as_bin(allow_backrefs=True) + self.assertEqual(len(b20_1), 48234495) + self.assertEqual(len(b20_2), 105) + + bomb_30 = make_bomb(30) + # do not uncomment the next line unless you want to run out of memory + # b30_1 = bomb_30.as_bin(allow_backrefs=False) + b30_2 = bomb_30.as_bin(allow_backrefs=True) + + # self.assertEqual(len(b30_1), 1) + self.assertEqual(len(b30_2), 135) + + def test_specific_tree(self): + sexp1 = to_sexp_f((("AAA", "BBB"), ("CCC", "AAA"))) + serialized_sexp1_v1 = sexp1.as_bin(allow_backrefs=False) + serialized_sexp1_v2 = sexp1.as_bin(allow_backrefs=True) + self.assertEqual(len(serialized_sexp1_v1), 19) + self.assertEqual(len(serialized_sexp1_v2), 17) + deserialized_sexp1_v1 = sexp_from_stream( + io.BytesIO(serialized_sexp1_v1), to_sexp_f, allow_backrefs=False + ) + deserialized_sexp1_v2 = sexp_from_stream( + io.BytesIO(serialized_sexp1_v2), to_sexp_f, allow_backrefs=True + ) + self.assertTrue(deserialized_sexp1_v1 == deserialized_sexp1_v2)