Skip to content

Commit

Permalink
Add compile script and diff test cairo, circuit, compiled and ref
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementWalter committed Jan 31, 2025
1 parent 002a7b6 commit e699f12
Show file tree
Hide file tree
Showing 8 changed files with 677 additions and 66 deletions.
6 changes: 6 additions & 0 deletions python/cairo-ec/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ dependencies = [
"cairo-core",
"cairo-lang>=0.13.3",
"maturin>=1.8.1",
"click>=8.1.7",
]

[project.scripts]
compile_circuit = "scripts.compile_circuit:main"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
Expand All @@ -19,6 +24,7 @@ build-backend = "hatchling.build"
include = [
"src/**/*.cairo", # Include all .cairo files in src directory
]
packages = ["src/cairo_ec", "scripts"]

[tool.hatch.build]
artifacts = [
Expand Down
111 changes: 111 additions & 0 deletions python/cairo-ec/scripts/compile_circuit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from pathlib import Path

import click
from jinja2 import Environment, FileSystemLoader
from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME
from starkware.cairo.lang.compiler.identifier_definition import (
FunctionDefinition,
StructDefinition,
)

from cairo_addons.compiler import cairo_compile
from cairo_ec.compiler import circuit_compile


class IntParamType(click.ParamType):
name = "integer"

def convert(self, value, param, ctx):
try:
if isinstance(value, int):
return value
if value.startswith("0x"):
return int(value, 16)
return int(value, 10)
except ValueError:
self.fail(f"{value!r} is not a valid integer", param, ctx)


INT = IntParamType()


def format_return_value(i: int) -> str:
"""Format a return value for the template."""
return f"cast(range_check96_ptr - {4 * (i + 1)}, UInt384*)"


def setup_jinja_env():
"""Set up the Jinja environment with the templates directory."""
templates_dir = Path(__file__).parent / "templates"
templates_dir.mkdir(parents=True, exist_ok=True)
env = Environment(loader=FileSystemLoader(templates_dir))
env.filters["format_return_value"] = format_return_value
return env


@click.command()
@click.argument(
"file_path",
type=click.Path(exists=True, dir_okay=False, path_type=Path),
required=False,
)
@click.option(
"--file_path",
"-f",
type=click.Path(exists=True, dir_okay=False, path_type=Path),
help="Path to the Cairo source file_path",
)
@click.option(
"--prime",
"-p",
type=INT,
default=DEFAULT_PRIME,
help="Prime number to use (can be decimal like 123 or hex like 0x7b)",
)
def main(file_path: Path | None, prime: int):
"""Compile a Cairo file_path and extract its circuits."""
if file_path is None:
raise click.UsageError("File path is required (either as argument or with -f)")

click.echo(f"Processing {file_path} with prime 0x{prime:x}")

# Set up Jinja environment
env = setup_jinja_env()
header_template = env.get_template("header.cairo.j2")
circuit_template = env.get_template("circuit.cairo.j2")

# Compile the Cairo file
program = cairo_compile(file_path, proof_mode=False, prime=prime)
functions = [
k.path[-1]
for k, v in program.identifiers.as_dict().items()
if isinstance(v, FunctionDefinition)
]
if not functions:
raise click.UsageError("No functions found in the file")

# Generate output code
output_parts = [header_template.render()]

# Process each function
for function in functions:
circuit = circuit_compile(program, function)
click.echo(f"Circuit {function}: {circuit}")

# Render template with all necessary data
circuit_code = circuit_template.render(
name=function,
args_struct=program.get_identifier(f"{function}.Args", StructDefinition),
return_data_size=circuit["return_data_size"],
circuit=circuit,
)
output_parts.append(circuit_code)

# Write all circuits to output file
output_path = file_path.parent / f"{file_path.stem}_compiled.cairo"
output_path.write_text("\n\n".join(output_parts))
click.echo(f"Generated circuit file: {output_path}")


if __name__ == "__main__":
main()
87 changes: 87 additions & 0 deletions python/cairo-ec/scripts/templates/circuit.cairo.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
{%- set return_values = ["UInt384*"] * (return_data_size // 4) %}
{%- set return_type = "" if not return_values else
"-> " + (("(" if return_values|length > 1 else "") +
return_values|join(", ") +
(")" if return_values|length > 1 else "")) %}

{%- set return_data = "()" %}
{%- if return_values %}
{%- set return_exprs = [] %}
{%- for i in range(return_data_size // 4)|reverse %}
{%- set offset = 4 * (i + 1) %}
{%- set _ = return_exprs.append("cast(range_check96_ptr - " ~ offset ~ ", UInt384*)") %}
{%- endfor %}
{%- set return_data = ("(" if return_values|length > 1 else "") ~ return_exprs|join(", ") ~ (")" if return_values|length > 1 else "") %}
{%- endif %}

func {{name}}{range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*}(
{%- for member_name in args_struct.members -%}
{{ member_name }}: UInt384*{{ ", " if not loop.last else "" }}
{%- endfor -%}, p: UInt384*
) {{return_type}} {
let (_, pc) = get_fp_and_pc();

pc_label:
let add_offsets_ptr = pc + (add_offsets - pc_label);
let mul_offsets_ptr = pc + (mul_offsets - pc_label);

{%- for member_name in args_struct.members %}
assert [range_check96_ptr + {{loop.index0 * 4}}] = {{member_name}}.d0;
assert [range_check96_ptr + {{loop.index0 * 4 + 1}}] = {{member_name}}.d1;
assert [range_check96_ptr + {{loop.index0 * 4 + 2}}] = {{member_name}}.d2;
assert [range_check96_ptr + {{loop.index0 * 4 + 3}}] = {{member_name}}.d3;
{% endfor %}

{%- if circuit.add_mod_n > 0 %}
assert add_mod_ptr[0] = ModBuiltin(
p=[p],
values_ptr=cast(range_check96_ptr, UInt384*),
offsets_ptr=add_offsets_ptr,
n={{circuit.add_mod_n}},
);
{%- endif %}

{%- if circuit.mul_mod_n > 0 %}
assert mul_mod_ptr[0] = ModBuiltin(
p=[p],
values_ptr=cast(range_check96_ptr, UInt384*),
offsets_ptr=mul_offsets_ptr,
n={{circuit.mul_mod_n}},
);
{%- endif %}

%{
from starkware.cairo.lang.builtins.modulo.mod_builtin_runner import ModBuiltinRunner
{% if circuit.add_mod_n > 0 %}
assert builtin_runners["add_mod_builtin"].instance_def.batch_size == 1
{%- endif -%}
{% if circuit.mul_mod_n > 0 %}
assert builtin_runners["mul_mod_builtin"].instance_def.batch_size == 1
{%- endif %}

ModBuiltinRunner.fill_memory(
memory=memory,
add_mod={{ "(ids.add_mod_ptr.address_, builtin_runners['add_mod_builtin'], " + circuit.add_mod_n|string + ")" if circuit.add_mod_n > 0 else "None" }},
mul_mod={{ "(ids.mul_mod_ptr.address_, builtin_runners['mul_mod_builtin'], " + circuit.mul_mod_n|string + ")" if circuit.mul_mod_n > 0 else "None" }},
)
%}

let range_check96_ptr = range_check96_ptr + {{circuit.total_offset}};
{% if circuit.add_mod_n > 0 %}
let add_mod_ptr = add_mod_ptr + ModBuiltin.SIZE * {{circuit.add_mod_n}};
{% endif %}
{%- if circuit.mul_mod_n > 0 -%}
let mul_mod_ptr = mul_mod_ptr + ModBuiltin.SIZE * {{circuit.mul_mod_n}};
{% endif %}
return {{return_data}};

add_offsets:
{%- for offset in circuit.add_mod_offsets_ptr %}
dw {{offset}};
{%- endfor %}

mul_offsets:
{%- for offset in circuit.mul_mod_offsets_ptr %}
dw {{offset}};
{%- endfor %}
}
2 changes: 2 additions & 0 deletions python/cairo-ec/scripts/templates/header.cairo.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from starkware.cairo.common.cairo_builtins import UInt384, ModBuiltin
from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc
Loading

0 comments on commit e699f12

Please sign in to comment.