-
Notifications
You must be signed in to change notification settings - Fork 278
/
Copy pathvalidated_memory_dict.py
67 lines (52 loc) · 2.96 KB
/
validated_memory_dict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Callable, Dict, List, Set, Tuple
from starkware.cairo.lang.vm.memory_dict import MemoryDict
from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue
# Type of a validation rule. ValidatedMemoryDict calls a rule as rule(memory, addr, *args), where
# memory is the underlying MemoryDict, addr is the written RelocatableValue, and args are the
# extra arguments that were passed to ValidatedMemoryDict.add_validation_rule. The rule returns
# the set of addresses it validated.
# The parameter list is "..." because rules may take those extra positional arguments, which a
# fixed two-argument Callable signature cannot express.
ValidationRule = Callable[..., Set[RelocatableValue]]
class ValidatedMemoryDict:
    """
    A proxy to MemoryDict which validates memory values in specific segments upon writing to it.
    Validation is done according to the validation rules.
    In addition, all values that are written through it are taken modulo the program's prime.

    Attributes that the proxy does not define itself are looked up on the wrapped MemoryDict.
    """

    def __init__(self, memory: MemoryDict, prime: int):
        """
        memory - the MemoryDict to wrap; reads and writes are delegated to it.
        prime - every value written through __setitem__ is reduced modulo this number.
        """
        self.__memory = memory
        self.prime = prime
        # validation_rules contains a mapping from a segment index to a list of functions
        # (and a tuple of additional arguments) that may try to validate the value of memory cells
        # in the segment (sometimes based on other memory cells).
        self.__validation_rules: Dict[int, List[Tuple[ValidationRule, tuple]]] = {}
        # The set of addresses which were already validated. A rule may report more addresses
        # than the one it was invoked on; all of them are recorded here.
        self.__validated_addresses: Set[RelocatableValue] = set()

    def __getitem__(self, addr: MaybeRelocatable) -> MaybeRelocatable:
        """Returns the value stored at addr in the wrapped memory."""
        return self.__memory[addr]

    def __setitem__(self, addr: MaybeRelocatable, value: MaybeRelocatable):
        """
        Writes value, reduced modulo self.prime, to addr in the wrapped memory, and then runs the
        validation rules registered for addr's segment.
        """
        value %= self.prime
        self.__memory[addr] = value
        self._validate_memory_cell(addr, value)

    def __getattr__(self, name: str):
        """
        Forwards lookups of attributes not found on the proxy to the wrapped MemoryDict.

        The copy/pickle hooks listed below are never forwarded (presumably so that copy and
        pickle fall back to their default handling of this object instead of using hooks
        found on the MemoryDict).

        Raises AttributeError for those hooks, and for any name when the proxy has no wrapped
        memory yet (an instance whose __init__ has not run).
        """
        if name in ["__deepcopy__", "__getstate__", "__setstate__"]:
            raise AttributeError(f"ValidatedMemoryDict has no attribute named {name}.")
        # Read the wrapped memory directly from the instance dict. Accessing self.__memory on an
        # instance created without __init__ (e.g. via __new__ while copying/unpickling) would
        # miss, re-enter __getattr__, and recurse until RecursionError.
        try:
            memory = self.__dict__["_ValidatedMemoryDict__memory"]
        except KeyError:
            raise AttributeError(f"ValidatedMemoryDict has no attribute named {name}.") from None
        return getattr(memory, name)

    def __iter__(self):
        """Iterates over the wrapped memory."""
        return iter(self.__memory)

    def __contains__(self, addr: MaybeRelocatable) -> bool:
        """Returns whether addr is present in the wrapped memory."""
        return addr in self.__memory

    def __len__(self):
        """Returns the length of the wrapped memory."""
        return len(self.__memory)

    def add_validation_rule(self, segment_index, rule: ValidationRule, *args):
        """
        Adds a validation rule on the given segment, which will be called upon writing to this
        segment (using setitem).
        'rule' is a callback function that gets the current memory, a memory address within the
        given segment and possibly some auxiliary arguments, which are the given args.
        The rule output is assumed to be the set of memory addresses validated by it.
        """
        self.__validation_rules.setdefault(segment_index, []).append((rule, args))

    def _validate_memory_cell(self, addr: MaybeRelocatable, value: MaybeRelocatable):
        """
        Runs every rule registered for addr's segment and records the addresses each rule
        reports as validated.

        Does nothing when addr is not a RelocatableValue or was already validated. Rules are
        called with the wrapped MemoryDict (not this proxy), addr, and their registered extra
        arguments. The value argument is not used here.
        """
        if not isinstance(addr, RelocatableValue) or addr in self.__validated_addresses:
            return
        for rule, args in self.__validation_rules.get(addr.segment_index, []):
            # set.update accepts any iterable of addresses; '|=' would require a set/frozenset.
            self.__validated_addresses.update(rule(self.__memory, addr, *args))

    def validate_existing_memory(self):
        """Runs the validation rules on every (address, value) pair already in the wrapped memory."""
        for addr, value in self.__memory.items():
            self._validate_memory_cell(addr, value)