"""
A CairoPie represents a position independent execution of a Cairo program.
"""
import contextlib
import copy
import dataclasses
import functools
import io
import json
import math
import zipfile
from dataclasses import field
from typing import Any, ClassVar, Dict, List, Mapping, Optional, Tuple, Type

import marshmallow
import marshmallow.fields as mfields
import marshmallow_dataclass

from starkware.cairo.lang.compiler.program import StrippedProgram, is_valid_builtin_name
from starkware.cairo.lang.vm.memory_dict import MemoryDict, RelocateValueFunc
from starkware.cairo.lang.vm.memory_segments import is_valid_memory_addr, is_valid_memory_value
from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue, relocate_value
from starkware.python.utils import add_counters, multiply_counter_by_scalar, sub_counters
from starkware.starkware_utils.marshmallow_dataclass_fields import additional_metadata

DEFAULT_CAIRO_PIE_VERSION = "1.0"
CURRENT_CAIRO_PIE_VERSION = "1.1"

MAX_N_STEPS = 2**30


@dataclasses.dataclass
class SegmentInfo:
    """
    Segment index and size.
    """

    index: int
    size: int

    def run_validity_checks(self):
        assert isinstance(self.index, int) and 0 <= self.index < 2**30, "Invalid segment index."
        assert isinstance(self.size, int) and 0 <= self.size < 2**30, "Invalid segment size."


@marshmallow_dataclass.dataclass
class CairoPieMetadata:
    """
    Metadata of a PIE output.
    """

    program: StrippedProgram
    program_segment: SegmentInfo
    execution_segment: SegmentInfo
    ret_fp_segment: SegmentInfo
    ret_pc_segment: SegmentInfo
    builtin_segments: Dict[str, SegmentInfo]
    extra_segments: List[SegmentInfo]

    Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema

    @property
    def field_bytes(self) -> int:
        return math.ceil(self.program.prime.bit_length() / 8)
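
    # Note: with Cairo's standard 252-bit prime, field_bytes evaluates to 32, i.e., each
    # memory value is serialized into 32 bytes.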

    def validate_segment_order(self):
        assert self.program_segment.index == 0, "Invalid segment index for program_segment."
        assert self.execution_segment.index == 1, "Invalid segment index for execution_segment."
        for expected_segment, (name, builtin_segment) in enumerate(
            self.builtin_segments.items(), 2
        ):
            assert builtin_segment.index == expected_segment, f"Invalid segment index for {name}."
        n_builtins = len(self.builtin_segments)
        assert (
            self.ret_fp_segment.index == n_builtins + 2
        ), f"Invalid segment index for ret_fp_segment: {self.ret_fp_segment.index}."
        assert (
            self.ret_pc_segment.index == n_builtins + 3
        ), "Invalid segment index for ret_pc_segment."
        for expected_segment, segment in enumerate(self.extra_segments, n_builtins + 4):
            assert segment.index == expected_segment, "Invalid segment indices for extra_segments."
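
    # Expected layout, illustrated for a program whose builtins are ["output", "pedersen"]:
    #     0: program, 1: execution, 2: output, 3: pedersen, 4: ret_fp, 5: ret_pc,
    #     6 onward: extra segments.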

    def all_segments(self) -> List[SegmentInfo]:
        """
        Returns a list of all the segments.
        """
        return [
            self.program_segment,
            self.execution_segment,
            self.ret_fp_segment,
            self.ret_pc_segment,
            *self.builtin_segments.values(),
            *self.extra_segments,
        ]

    def segment_sizes(self) -> Dict[int, int]:
        """
        Returns a map from segment index to its size.
        """
        return {segment.index: segment.size for segment in self.all_segments()}

    def run_validity_checks(self):
        self.program.run_validity_checks()

        assert isinstance(self.builtin_segments, dict) and all(
            is_valid_builtin_name(name) for name in self.builtin_segments.keys()
        ), "Invalid builtin_segments."
        assert isinstance(self.extra_segments, list), "Invalid type for extra_segments."

        for segment_info in self.all_segments():
            assert isinstance(segment_info, SegmentInfo), "Invalid type for segment_info."
            segment_info.run_validity_checks()

        assert self.program_segment.size == len(
            self.program.data
        ), "Program length does not match the program segment size."
        assert self.program.builtins == list(self.builtin_segments.keys()), (
            f"Builtin list mismatch in builtin_segments. Builtins: {self.program.builtins}, "
            f"segment keys: {list(self.builtin_segments.keys())}."
        )
        assert self.ret_fp_segment.size == 0, "Invalid segment size for ret_fp. Must be 0."
        assert self.ret_pc_segment.size == 0, "Invalid segment size for ret_pc. Must be 0."

        self.validate_segment_order()


@marshmallow_dataclass.dataclass
class ExecutionResources:
    """
    Indicates how many steps the program should run, how many memory cells are used from each
    builtin, and how many holes there are in the memory address space.
    """

    n_steps: int
    builtin_instance_counter: Dict[str, int]
    n_memory_holes: int = field(
        metadata=additional_metadata(marshmallow_field=mfields.Integer(load_default=0))
    )

    Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema

    def run_validity_checks(self):
        assert (
            isinstance(self.n_steps, int) and 1 <= self.n_steps < MAX_N_STEPS
        ), f"Invalid n_steps: {self.n_steps}."
        assert (
            isinstance(self.n_memory_holes, int) and 0 <= self.n_memory_holes < 2**30
        ), f"Invalid n_memory_holes: {self.n_memory_holes}."
        assert isinstance(self.builtin_instance_counter, dict) and all(
            is_valid_builtin_name(name) and isinstance(size, int) and 0 <= size < 2**30
            for name, size in self.builtin_instance_counter.items()
        ), "Invalid builtin_instance_counter."

    def __add__(self, other: "ExecutionResources") -> "ExecutionResources":
        total_builtin_instance_counter = add_counters(
            self.builtin_instance_counter, other.builtin_instance_counter
        )
        return ExecutionResources(
            n_steps=self.n_steps + other.n_steps,
            builtin_instance_counter=total_builtin_instance_counter,
            n_memory_holes=self.n_memory_holes + other.n_memory_holes,
        )

    def __sub__(self, other: "ExecutionResources") -> "ExecutionResources":
        diff_builtin_instance_counter = sub_counters(
            self.builtin_instance_counter, other.builtin_instance_counter
        )
        return ExecutionResources(
            n_steps=self.n_steps - other.n_steps,
            builtin_instance_counter=diff_builtin_instance_counter,
            n_memory_holes=self.n_memory_holes - other.n_memory_holes,
        )

    def __mul__(self, other: int) -> "ExecutionResources":
        if not isinstance(other, int):
            return NotImplemented
        total_builtin_instance_counter = multiply_counter_by_scalar(
            scalar=other, counter=self.builtin_instance_counter
        )
        return ExecutionResources(
            n_steps=other * self.n_steps,
            builtin_instance_counter=total_builtin_instance_counter,
            n_memory_holes=other * self.n_memory_holes,
        )

    def __rmul__(self, other: int) -> "ExecutionResources":
        return self * other
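
    # Illustrative example: adding the resources of two runs sums the step counts and merges
    # the builtin counters key-wise, e.g.
    #     (n_steps=10, {"pedersen_builtin": 2}, n_memory_holes=0)
    #     + (n_steps=5, {"pedersen_builtin": 1}, n_memory_holes=3)
    # yields (n_steps=15, {"pedersen_builtin": 3}, n_memory_holes=3).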

    @classmethod
    def empty(cls):
        return cls(n_steps=0, builtin_instance_counter={}, n_memory_holes=0)

    def copy(self) -> "ExecutionResources":
        return copy.deepcopy(self)
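
    # Note: to_dict (below) folds n_memory_holes into the reported n_steps; callers that need
    # the raw step count should read the n_steps attribute directly.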
    def to_dict(self) -> Dict[str, int]:
        return dict(
            **self.builtin_instance_counter,
            n_steps=self.n_steps + self.n_memory_holes,
        )

    def filter_unused_builtins(self) -> "ExecutionResources":
        """
        Returns a copy of the execution resources where all the builtins with a usage counter
        of 0 are omitted.
        """
        return dataclasses.replace(
            self,
            builtin_instance_counter={
                name: counter
                for name, counter in self.builtin_instance_counter.items()
                if counter > 0
            },
        )


@dataclasses.dataclass
class CairoPie:
    """
    A CairoPie is a serializable object containing information about a run of a Cairo program.
    Using this information, one can 'relocate' segments of the run to form another valid Cairo
    run. For example, this may be used to join a few Cairo runs into one, by concatenating
    respective segments.
    """

    metadata: CairoPieMetadata
    memory: MemoryDict
    additional_data: Dict[str, Any]
    execution_resources: ExecutionResources
    version: Dict[str, str] = field(
        default_factory=lambda: {"cairo_pie": CURRENT_CAIRO_PIE_VERSION}
    )

    METADATA_FILENAME = "metadata.json"
    MEMORY_FILENAME = "memory.bin"
    ADDITIONAL_DATA_FILENAME = "additional_data.json"
    EXECUTION_RESOURCES_FILENAME = "execution_resources.json"
    VERSION_FILENAME = "version.json"
    OPTIONAL_FILES = [VERSION_FILENAME]
    ALL_FILES = [
        METADATA_FILENAME,
        MEMORY_FILENAME,
        ADDITIONAL_DATA_FILENAME,
        EXECUTION_RESOURCES_FILENAME,
    ] + OPTIONAL_FILES
    MAX_SIZE = 5 * 1024**3
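    # MAX_SIZE (5 * 1024**3 = 5 GiB) caps both the number of bytes read from each inner file
    # and, in verify_zip_format below, the declared size of the large inner files.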

    @classmethod
    def from_file(cls, fileobj) -> "CairoPie":
        """
        Loads an instance of CairoPie from a file.
        `fileobj` can be a path or a file object.
        """
        if isinstance(fileobj, str):
            fileobj = open(fileobj, "rb")

        verify_zip_file_prefix(fileobj=fileobj)

        with zipfile.ZipFile(fileobj) as zf:
            cls.verify_zip_format(zf)

            with zf.open(cls.METADATA_FILENAME, "r") as fp:
                metadata = CairoPieMetadata.Schema().load(
                    json.loads(fp.read(cls.MAX_SIZE).decode("ascii"))
                )
            with zf.open(cls.MEMORY_FILENAME, "r") as fp:
                memory = MemoryDict.deserialize(
                    data=fp.read(cls.MAX_SIZE),
                    field_bytes=metadata.field_bytes,
                )
            with zf.open(cls.ADDITIONAL_DATA_FILENAME, "r") as fp:
                additional_data = json.loads(fp.read(cls.MAX_SIZE).decode("ascii"))
            with zf.open(cls.EXECUTION_RESOURCES_FILENAME, "r") as fp:
                execution_resources = ExecutionResources.Schema().load(
                    json.loads(fp.read(cls.MAX_SIZE).decode("ascii"))
                )
            version = {"cairo_pie": DEFAULT_CAIRO_PIE_VERSION}
            if cls.VERSION_FILENAME in zf.namelist():
                with zf.open(cls.VERSION_FILENAME, "r") as fp:
                    version = json.loads(fp.read(cls.MAX_SIZE).decode("ascii"))

        return cls(metadata, memory, additional_data, execution_resources, version)

    def merge_extra_segments(self) -> Tuple[List[SegmentInfo], Dict[int, RelocatableValue]]:
        """
        Merges extra_segments to one segment.
        Returns a tuple of the new extra_segments (which contains a single merged segment) and a
        dictionary from old segment index to its offset in the new segment.
        """
        assert len(self.metadata.extra_segments) > 0
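
        # Illustrative example (hypothetical values): extra segments [(index=7, size=10),
        # (index=8, size=5)] merge into [SegmentInfo(index=7, size=15)], and the returned
        # offsets map segment 8 to RelocatableValue(7, 10).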

        # Take the index of the segment from the first merged segment.
        new_segment_index = self.metadata.extra_segments[0].index
        segment_offsets = {}
        segments_accumulated_size = 0
        for segment in self.metadata.extra_segments:
            segment_offsets[segment.index] = RelocatableValue(
                new_segment_index, segments_accumulated_size
            )
            segments_accumulated_size += segment.size

        return (
            [SegmentInfo(index=new_segment_index, size=segments_accumulated_size)],
            segment_offsets,
        )

    def get_relocate_value_func(
        self, segment_offsets: Optional[Mapping[int, MaybeRelocatable]]
    ) -> Optional[RelocateValueFunc]:
        """
        Returns a relocate_value function that relocates values according to the given segment
        offsets.
        """
        if segment_offsets is None:
            return None

        return functools.partial(
            relocate_value,
            segment_offsets=segment_offsets,
            prime=self.program.prime,
            # The known segments (such as builtins) are missing since we do not want to relocate
            # them.
            allow_missing_segments=True,
        )

    def to_file(self, file, merge_extra_segments: bool = False):
        extra_segments, segment_offsets = (
            self.merge_extra_segments()
            if merge_extra_segments and len(self.metadata.extra_segments) > 0
            else (None, None)
        )
        metadata = self.metadata
        if extra_segments is not None:
            metadata = dataclasses.replace(metadata, extra_segments=extra_segments)

        with zipfile.ZipFile(file, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
            with zf.open(self.METADATA_FILENAME, "w") as fp:
                fp.write(json.dumps(CairoPieMetadata.Schema().dump(metadata)).encode("ascii"))
            with zf.open(self.MEMORY_FILENAME, "w", force_zip64=True) as fp:
                fp.write(
                    self.memory.serialize(
                        field_bytes=self.metadata.field_bytes,
                        relocate_value=self.get_relocate_value_func(
                            segment_offsets=segment_offsets
                        ),
                    )
                )
            with zf.open(self.ADDITIONAL_DATA_FILENAME, "w") as fp:
                fp.write(json.dumps(self.additional_data).encode("ascii"))
            with zf.open(self.EXECUTION_RESOURCES_FILENAME, "w") as fp:
                fp.write(
                    json.dumps(ExecutionResources.Schema().dump(self.execution_resources)).encode(
                        "ascii"
                    )
                )
            with zf.open(self.VERSION_FILENAME, "w") as fp:
                fp.write(json.dumps(self.version).encode("ascii"))

    @classmethod
    def deserialize(cls, cairo_pie_bytes: bytes) -> "CairoPie":
        cairo_pie_file = io.BytesIO()
        cairo_pie_file.write(cairo_pie_bytes)
        return CairoPie.from_file(fileobj=cairo_pie_file)

    def serialize(self) -> bytes:
        cairo_pie_file = io.BytesIO()
        self.to_file(file=cairo_pie_file)
        return cairo_pie_file.getvalue()

    @property
    def program(self) -> StrippedProgram:
        return self.metadata.program
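
    # Usage sketch (illustrative; assumes `pie` is a CairoPie obtained from a run):
    #     data = pie.serialize()
    #     restored = CairoPie.deserialize(data)
    #     restored.run_validity_checks()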

    def run_validity_checks(self):
        self.metadata.run_validity_checks()
        self.execution_resources.run_validity_checks()

        assert isinstance(self.memory, MemoryDict), "Invalid type for memory."
        self.run_memory_validity_checks()

        assert sorted(f"{name}_builtin" for name in self.metadata.program.builtins) == sorted(
            self.execution_resources.builtin_instance_counter.keys()
        ), "Builtin list mismatch in execution_resources."

        assert isinstance(self.additional_data, dict) and all(
            isinstance(name, str) and len(name) < 1000 for name in self.additional_data
        ), "Invalid additional_data."

    def run_memory_validity_checks(self):
        segment_sizes = self.metadata.segment_sizes()
        for addr, value in self.memory.items():
            assert is_valid_memory_addr(
                addr=addr, segment_sizes=segment_sizes
            ), "Invalid memory cell address."
            assert is_valid_memory_value(
                value=value, segment_sizes=segment_sizes
            ), "Invalid memory cell value."

    @classmethod
    def verify_zip_format(cls, zf: zipfile.ZipFile):
        """
        Checks that the given zip file contains the expected inner files, that their compression
        type is ZIP_DEFLATED, and that their sizes are not too big.
        """
        # Check the compression algorithm.
        assert all(
            zip_info.compress_type == zipfile.ZIP_DEFLATED for zip_info in zf.filelist
        ), "Invalid compress type."
        # Check that orig_filename == filename.
        # Use "type: ignore" since mypy doesn't recognize ZipInfo.orig_filename.
        assert all(
            zip_info.orig_filename == zip_info.filename for zip_info in zf.filelist  # type: ignore
        ), "File name mismatch."

        # Make sure we have exactly the files we expect.
        inner_files = {zip_info.filename: zip_info for zip_info in zf.filelist}
        assert sorted(inner_files.keys() | cls.OPTIONAL_FILES) == sorted(
            cls.ALL_FILES
        ), "Invalid list of inner files in the CairoPIE zip."

        # Make sure the file sizes are reasonable.
        for name, limit in (
            (cls.METADATA_FILENAME, cls.MAX_SIZE),
            (cls.MEMORY_FILENAME, cls.MAX_SIZE),
            (cls.ADDITIONAL_DATA_FILENAME, cls.MAX_SIZE),
            (cls.EXECUTION_RESOURCES_FILENAME, 10000),
            (cls.VERSION_FILENAME, 10000),
        ):
            size = inner_files[name].file_size if name in inner_files else 0
            assert size < limit, f"Invalid file size {size} for {name}; limit is {limit}."

    def get_segment(self, segment_info: SegmentInfo):
        return self.memory.get_range(
            RelocatableValue(segment_index=segment_info.index, offset=0), size=segment_info.size
        )

    def is_compatible_with(self, other: "CairoPie") -> bool:
        """
        Checks equality between two CairoPies. Ignores .additional_data["pedersen_builtin"]
        to avoid an issue where a stricter run checks more Pedersen addresses and results
        in a different address list.
        """
        with ignore_pedersen_data(self):
            with ignore_pedersen_data(other):
                return self == other

    def diff(self, other: "CairoPie") -> str:
        """
        Returns a short description of the diff between two CairoPies.
        """
        res = ["CairoPie diff:"]
        if self.metadata != other.metadata:
            res.append(" * metadata mismatch.")
        if self.memory != other.memory:
            res.append(" * memory mismatch.")
        if self.additional_data != other.additional_data:
            res.append(" * additional_data mismatch:")
            for key in sorted(self.additional_data.keys() | other.additional_data.keys()):
                if self.additional_data.get(key) != other.additional_data.get(key):
                    res.append(f" * {key} mismatch.")
        if self.execution_resources != other.execution_resources:
            res.append(
                " * execution_resources mismatch: "
                f"{self.execution_resources} != {other.execution_resources}."
            )
        if self.version != other.version:
            res.append(f" * version mismatch: {self.version} != {other.version}.")
        return "\n".join(res)


def verify_zip_file_prefix(fileobj):
    """
    Verifies that the file starts with the zip file prefix.
    """
    fileobj.seek(0)
    # Make sure this is a zip file.
    assert fileobj.read(2) in ["PK", b"PK"], "Invalid prefix for zip file."
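    # ("PK", the initials of Phil Katz, is the two-byte signature that prefixes zip headers.)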


@contextlib.contextmanager
def ignore_pedersen_data(pie: CairoPie):
    """
    Context manager under which pie.additional_data["pedersen_builtin"] is set to None and
    reverted to its original value (or removed if it didn't exist before) when the context
    terminates.
    """
    should_pop = "pedersen_builtin" not in pie.additional_data

    original_pedersen_data, pie.additional_data["pedersen_builtin"] = (
        pie.additional_data.get("pedersen_builtin"),
        None,
    )
    try:
        yield
    finally:
        if should_pop:
            pie.additional_data.pop("pedersen_builtin")
        else:
            pie.additional_data["pedersen_builtin"] = original_pedersen_data