-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Hints rust/python scaffolding (#373)
- Loading branch information
1 parent
5c1ee08
commit cd9a1ee
Showing
6 changed files
with
138 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# ruff: noqa: F403 | ||
from cairo_addons.hints.decorator import implementations, register_hint | ||
from cairo_addons.hints.dict import * | ||
from cairo_addons.hints.os import * | ||
|
||
__all__ = [ | ||
"register_hint", | ||
"implementations", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import ast | ||
import inspect | ||
|
||
implementations = {} | ||
|
||
|
||
def get_function_body(func) -> str: | ||
"""Extract just the body of a function as a string.""" | ||
source = inspect.getsource(func) | ||
|
||
# Parse the source into an AST | ||
tree = ast.parse(source) | ||
|
||
# Get the function definition node | ||
func_def = tree.body[0] | ||
|
||
# Split source into lines | ||
lines = source.splitlines() | ||
|
||
# Handle single-line functions (body on same line as def) | ||
if len(func_def.body) == 1 and isinstance(func_def.body[0], ast.Expr): | ||
body = str(func_def.body[0].value.value) # For docstrings | ||
# Handle single-line functions with return/assign/etc | ||
elif len(lines) <= func_def.body[0].lineno: | ||
body_lines = [lines[-1]] # Take last line as body | ||
indent = len(body_lines[0]) - len(body_lines[0].lstrip()) | ||
body = body_lines[0][indent:] | ||
else: | ||
# Multi-line function - get all non-empty lines after def | ||
body_lines = [ | ||
line for line in lines[func_def.body[0].lineno - 1 :] if line != "" | ||
] | ||
if body_lines: | ||
indent = len(body_lines[0]) - len(body_lines[0].lstrip()) | ||
body = "\n".join(line[indent:] for line in body_lines) | ||
else: | ||
body = "pass" # Empty function body | ||
|
||
return body | ||
|
||
|
||
def register_hint(wrapped_function): | ||
implementations[wrapped_function.__name__] = get_function_body(wrapped_function) | ||
|
||
return wrapped_function |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from cairo_addons.hints.decorator import register_hint | ||
from starkware.cairo.common.dict import DictManager | ||
from starkware.cairo.lang.vm.memory_dict import MemoryDict | ||
from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager | ||
from starkware.cairo.lang.vm.relocatable import RelocatableValue | ||
from starkware.cairo.lang.vm.vm_consts import VmConsts | ||
|
||
|
||
@register_hint | ||
def dict_copy(dict_manager: DictManager, ids: VmConsts): | ||
from starkware.cairo.common.dict import DictTracker | ||
|
||
if ids.new_start.address_.segment_index in dict_manager.trackers: | ||
raise ValueError( | ||
f"Segment {ids.new_start.address_.segment_index} already exists in dict_manager.trackers" | ||
) | ||
|
||
data = dict_manager.trackers[ids.dict_start.address_.segment_index].data.copy() | ||
dict_manager.trackers[ids.new_start.address_.segment_index] = DictTracker( | ||
data=data, | ||
current_ptr=ids.new_end.address_, | ||
) | ||
|
||
|
||
@register_hint | ||
def dict_squash( | ||
dict_manager: DictManager, | ||
ids: VmConsts, | ||
segments: MemorySegmentManager, | ||
memory: MemoryDict, | ||
ap: RelocatableValue, | ||
): | ||
from starkware.cairo.common.dict import DictTracker | ||
|
||
data = dict_manager.get_dict(ids.dict_accesses_end).copy() | ||
base = segments.add() | ||
assert base.segment_index not in dict_manager.trackers | ||
dict_manager.trackers[base.segment_index] = DictTracker(data=data, current_ptr=base) | ||
memory[ap] = base |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from cairo_addons.hints.decorator import register_hint | ||
from starkware.cairo.common.dict import DictManager | ||
from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager | ||
from starkware.cairo.lang.vm.vm_consts import VmConsts | ||
|
||
|
||
@register_hint | ||
def block( | ||
dict_manager: DictManager, | ||
segments: MemorySegmentManager, | ||
program_input: dict, | ||
ids: VmConsts, | ||
): | ||
from tests.utils.hints import gen_arg_pydantic | ||
|
||
ids.block = gen_arg_pydantic(dict_manager, segments, program_input["block"]) | ||
|
||
|
||
@register_hint | ||
def state( | ||
dict_manager: DictManager, | ||
segments: MemorySegmentManager, | ||
program_input: dict, | ||
ids: VmConsts, | ||
): | ||
from tests.utils.hints import gen_arg_pydantic | ||
|
||
ids.state = gen_arg_pydantic(dict_manager, segments, program_input["state"]) | ||
|
||
|
||
@register_hint | ||
def chain_id(ids: VmConsts): | ||
ids.chain_id = 1 | ||
|
||
|
||
@register_hint | ||
def block_hashes(segments: MemorySegmentManager, ids: VmConsts): | ||
import random | ||
|
||
ids.block_hashes = segments.gen_arg( | ||
[random.randint(0, 2**128 - 1) for _ in range(256 * 2)] | ||
) |