Skip to content

Commit

Permalink
Hints rust/python scaffolding (#373)
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementWalter authored Jan 8, 2025
1 parent 5c1ee08 commit cd9a1ee
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 68 deletions.
70 changes: 2 additions & 68 deletions cairo/src/utils/compiler.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,13 @@
from pathlib import Path

from cairo_addons.hints import implementations
from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME
from starkware.cairo.lang.compiler.cairo_compile import compile_cairo, get_module_reader
from starkware.cairo.lang.compiler.preprocessor.default_pass_manager import (
default_pass_manager,
)
from starkware.cairo.lang.compiler.program import CairoHint

from src.utils.constants import CHAIN_ID

dict_manager = """
if '__dict_manager' not in globals():
from starkware.cairo.common.dict import DictManager
__dict_manager = DictManager()
"""

dict_copy = """
from starkware.cairo.common.dict import DictTracker
if ids.new_start.address_.segment_index in __dict_manager.trackers:
raise ValueError(f"Segment {ids.new_start.address_.segment_index} already exists in __dict_manager.trackers")
data = __dict_manager.trackers[ids.dict_start.address_.segment_index].data.copy()
__dict_manager.trackers[ids.new_start.address_.segment_index] = DictTracker(
data=data,
current_ptr=ids.new_end.address_,
)
"""

dict_squash = """
from starkware.cairo.common.dict import DictTracker
data = __dict_manager.get_dict(ids.dict_accesses_end).copy()
base = segments.add()
assert base.segment_index not in __dict_manager.trackers
__dict_manager.trackers[base.segment_index] = DictTracker(
data=data, current_ptr=base
)
memory[ap] = base
"""

block = f"""
{dict_manager}
from tests.utils.hints import gen_arg_pydantic
ids.block = gen_arg_pydantic(__dict_manager, segments, program_input["block"])
"""

state = f"""
{dict_manager}
from tests.utils.hints import gen_arg_pydantic
ids.state = gen_arg_pydantic(__dict_manager, segments, program_input["state"])
"""

chain_id = f"""
ids.chain_id = {CHAIN_ID}
"""

block_hashes = """
import random
ids.block_hashes = segments.gen_arg([random.randint(0, 2**128 - 1) for _ in range(256 * 2)])
"""

hints = {
"dict_manager": dict_manager,
"dict_copy": dict_copy,
"dict_squash": dict_squash,
"block": block,
"state": state,
"chain_id": chain_id,
"block_hashes": block_hashes,
}


def implement_hints(program):
return {
Expand All @@ -82,7 +16,7 @@ def implement_hints(program):
CairoHint(
accessible_scopes=hint_.accessible_scopes,
flow_tracking_data=hint_.flow_tracking_data,
code=hints.get(hint_.code, hint_.code),
code=implementations.get(hint_.code, hint_.code),
)
)
for hint_ in v
Expand Down
1 change: 1 addition & 0 deletions cairo/tests/fixtures/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def _factory_py(entrypoint, *args, **kwargs):
hint_locals={
"program_input": kwargs,
"__dict_manager": dict_manager,
"dict_manager": dict_manager,
"gen_arg": gen_arg,
"serde": serde,
"oracle": oracle(cairo_program, serde, main_path, gen_arg),
Expand Down
9 changes: 9 additions & 0 deletions python/cairo-addons/src/cairo_addons/hints/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# ruff: noqa: F403
from cairo_addons.hints.decorator import implementations, register_hint
from cairo_addons.hints.dict import *
from cairo_addons.hints.os import *

__all__ = [
"register_hint",
"implementations",
]
45 changes: 45 additions & 0 deletions python/cairo-addons/src/cairo_addons/hints/decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import ast
import inspect

implementations = {}


def get_function_body(func) -> str:
"""Extract just the body of a function as a string."""
source = inspect.getsource(func)

# Parse the source into an AST
tree = ast.parse(source)

# Get the function definition node
func_def = tree.body[0]

# Split source into lines
lines = source.splitlines()

# Handle single-line functions (body on same line as def)
if len(func_def.body) == 1 and isinstance(func_def.body[0], ast.Expr):
body = str(func_def.body[0].value.value) # For docstrings
# Handle single-line functions with return/assign/etc
elif len(lines) <= func_def.body[0].lineno:
body_lines = [lines[-1]] # Take last line as body
indent = len(body_lines[0]) - len(body_lines[0].lstrip())
body = body_lines[0][indent:]
else:
# Multi-line function - get all non-empty lines after def
body_lines = [
line for line in lines[func_def.body[0].lineno - 1 :] if line != ""
]
if body_lines:
indent = len(body_lines[0]) - len(body_lines[0].lstrip())
body = "\n".join(line[indent:] for line in body_lines)
else:
body = "pass" # Empty function body

return body


def register_hint(wrapped_function):
implementations[wrapped_function.__name__] = get_function_body(wrapped_function)

return wrapped_function
39 changes: 39 additions & 0 deletions python/cairo-addons/src/cairo_addons/hints/dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from cairo_addons.hints.decorator import register_hint
from starkware.cairo.common.dict import DictManager
from starkware.cairo.lang.vm.memory_dict import MemoryDict
from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager
from starkware.cairo.lang.vm.relocatable import RelocatableValue
from starkware.cairo.lang.vm.vm_consts import VmConsts


@register_hint
def dict_copy(dict_manager: DictManager, ids: VmConsts):
from starkware.cairo.common.dict import DictTracker

if ids.new_start.address_.segment_index in dict_manager.trackers:
raise ValueError(
f"Segment {ids.new_start.address_.segment_index} already exists in dict_manager.trackers"
)

data = dict_manager.trackers[ids.dict_start.address_.segment_index].data.copy()
dict_manager.trackers[ids.new_start.address_.segment_index] = DictTracker(
data=data,
current_ptr=ids.new_end.address_,
)


@register_hint
def dict_squash(
dict_manager: DictManager,
ids: VmConsts,
segments: MemorySegmentManager,
memory: MemoryDict,
ap: RelocatableValue,
):
from starkware.cairo.common.dict import DictTracker

data = dict_manager.get_dict(ids.dict_accesses_end).copy()
base = segments.add()
assert base.segment_index not in dict_manager.trackers
dict_manager.trackers[base.segment_index] = DictTracker(data=data, current_ptr=base)
memory[ap] = base
42 changes: 42 additions & 0 deletions python/cairo-addons/src/cairo_addons/hints/os.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from cairo_addons.hints.decorator import register_hint
from starkware.cairo.common.dict import DictManager
from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager
from starkware.cairo.lang.vm.vm_consts import VmConsts


@register_hint
def block(
dict_manager: DictManager,
segments: MemorySegmentManager,
program_input: dict,
ids: VmConsts,
):
from tests.utils.hints import gen_arg_pydantic

ids.block = gen_arg_pydantic(dict_manager, segments, program_input["block"])


@register_hint
def state(
dict_manager: DictManager,
segments: MemorySegmentManager,
program_input: dict,
ids: VmConsts,
):
from tests.utils.hints import gen_arg_pydantic

ids.state = gen_arg_pydantic(dict_manager, segments, program_input["state"])


@register_hint
def chain_id(ids: VmConsts):
ids.chain_id = 1


@register_hint
def block_hashes(segments: MemorySegmentManager, ids: VmConsts):
import random

ids.block_hashes = segments.gen_arg(
[random.randint(0, 2**128 - 1) for _ in range(256 * 2)]
)

0 comments on commit cd9a1ee

Please sign in to comment.