-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add compile script and diff test cairo, circuit, compiled and ref
- Loading branch information
1 parent
002a7b6
commit e699f12
Showing
8 changed files
with
677 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from pathlib import Path | ||
|
||
import click | ||
from jinja2 import Environment, FileSystemLoader | ||
from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME | ||
from starkware.cairo.lang.compiler.identifier_definition import ( | ||
FunctionDefinition, | ||
StructDefinition, | ||
) | ||
|
||
from cairo_addons.compiler import cairo_compile | ||
from cairo_ec.compiler import circuit_compile | ||
|
||
|
||
class IntParamType(click.ParamType): | ||
name = "integer" | ||
|
||
def convert(self, value, param, ctx): | ||
try: | ||
if isinstance(value, int): | ||
return value | ||
if value.startswith("0x"): | ||
return int(value, 16) | ||
return int(value, 10) | ||
except ValueError: | ||
self.fail(f"{value!r} is not a valid integer", param, ctx) | ||
|
||
|
||
INT = IntParamType() | ||
|
||
|
||
def format_return_value(i: int) -> str: | ||
"""Format a return value for the template.""" | ||
return f"cast(range_check96_ptr - {4 * (i + 1)}, UInt384*)" | ||
|
||
|
||
def setup_jinja_env(): | ||
"""Set up the Jinja environment with the templates directory.""" | ||
templates_dir = Path(__file__).parent / "templates" | ||
templates_dir.mkdir(parents=True, exist_ok=True) | ||
env = Environment(loader=FileSystemLoader(templates_dir)) | ||
env.filters["format_return_value"] = format_return_value | ||
return env | ||
|
||
|
||
@click.command() | ||
@click.argument( | ||
"file_path", | ||
type=click.Path(exists=True, dir_okay=False, path_type=Path), | ||
required=False, | ||
) | ||
@click.option( | ||
"--file_path", | ||
"-f", | ||
type=click.Path(exists=True, dir_okay=False, path_type=Path), | ||
help="Path to the Cairo source file_path", | ||
) | ||
@click.option( | ||
"--prime", | ||
"-p", | ||
type=INT, | ||
default=DEFAULT_PRIME, | ||
help="Prime number to use (can be decimal like 123 or hex like 0x7b)", | ||
) | ||
def main(file_path: Path | None, prime: int): | ||
"""Compile a Cairo file_path and extract its circuits.""" | ||
if file_path is None: | ||
raise click.UsageError("File path is required (either as argument or with -f)") | ||
|
||
click.echo(f"Processing {file_path} with prime 0x{prime:x}") | ||
|
||
# Set up Jinja environment | ||
env = setup_jinja_env() | ||
header_template = env.get_template("header.cairo.j2") | ||
circuit_template = env.get_template("circuit.cairo.j2") | ||
|
||
# Compile the Cairo file | ||
program = cairo_compile(file_path, proof_mode=False, prime=prime) | ||
functions = [ | ||
k.path[-1] | ||
for k, v in program.identifiers.as_dict().items() | ||
if isinstance(v, FunctionDefinition) | ||
] | ||
if not functions: | ||
raise click.UsageError("No functions found in the file") | ||
|
||
# Generate output code | ||
output_parts = [header_template.render()] | ||
|
||
# Process each function | ||
for function in functions: | ||
circuit = circuit_compile(program, function) | ||
click.echo(f"Circuit {function}: {circuit}") | ||
|
||
# Render template with all necessary data | ||
circuit_code = circuit_template.render( | ||
name=function, | ||
args_struct=program.get_identifier(f"{function}.Args", StructDefinition), | ||
return_data_size=circuit["return_data_size"], | ||
circuit=circuit, | ||
) | ||
output_parts.append(circuit_code) | ||
|
||
# Write all circuits to output file | ||
output_path = file_path.parent / f"{file_path.stem}_compiled.cairo" | ||
output_path.write_text("\n\n".join(output_parts)) | ||
click.echo(f"Generated circuit file: {output_path}") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
{%- set return_values = ["UInt384*"] * (return_data_size // 4) %} | ||
{%- set return_type = "" if not return_values else | ||
"-> " + (("(" if return_values|length > 1 else "") + | ||
return_values|join(", ") + | ||
(")" if return_values|length > 1 else "")) %} | ||
|
||
{%- set return_data = "()" %} | ||
{%- if return_values %} | ||
{%- set return_exprs = [] %} | ||
{%- for i in range(return_data_size // 4)|reverse %} | ||
{%- set offset = 4 * (i + 1) %} | ||
{%- set _ = return_exprs.append("cast(range_check96_ptr - " ~ offset ~ ", UInt384*)") %} | ||
{%- endfor %} | ||
{%- set return_data = ("(" if return_values|length > 1 else "") ~ return_exprs|join(", ") ~ (")" if return_values|length > 1 else "") %} | ||
{%- endif %} | ||
|
||
func {{name}}{range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*}( | ||
{%- for member_name in args_struct.members -%} | ||
{{ member_name }}: UInt384*{{ ", " if not loop.last else "" }} | ||
{%- endfor -%}, p: UInt384* | ||
) {{return_type}} { | ||
let (_, pc) = get_fp_and_pc(); | ||
|
||
pc_label: | ||
let add_offsets_ptr = pc + (add_offsets - pc_label); | ||
let mul_offsets_ptr = pc + (mul_offsets - pc_label); | ||
|
||
{%- for member_name in args_struct.members %} | ||
assert [range_check96_ptr + {{loop.index0 * 4}}] = {{member_name}}.d0; | ||
assert [range_check96_ptr + {{loop.index0 * 4 + 1}}] = {{member_name}}.d1; | ||
assert [range_check96_ptr + {{loop.index0 * 4 + 2}}] = {{member_name}}.d2; | ||
assert [range_check96_ptr + {{loop.index0 * 4 + 3}}] = {{member_name}}.d3; | ||
{% endfor %} | ||
|
||
{%- if circuit.add_mod_n > 0 %} | ||
assert add_mod_ptr[0] = ModBuiltin( | ||
p=[p], | ||
values_ptr=cast(range_check96_ptr, UInt384*), | ||
offsets_ptr=add_offsets_ptr, | ||
n={{circuit.add_mod_n}}, | ||
); | ||
{%- endif %} | ||
|
||
{%- if circuit.mul_mod_n > 0 %} | ||
assert mul_mod_ptr[0] = ModBuiltin( | ||
p=[p], | ||
values_ptr=cast(range_check96_ptr, UInt384*), | ||
offsets_ptr=mul_offsets_ptr, | ||
n={{circuit.mul_mod_n}}, | ||
); | ||
{%- endif %} | ||
|
||
%{ | ||
from starkware.cairo.lang.builtins.modulo.mod_builtin_runner import ModBuiltinRunner | ||
{% if circuit.add_mod_n > 0 %} | ||
assert builtin_runners["add_mod_builtin"].instance_def.batch_size == 1 | ||
{%- endif -%} | ||
{% if circuit.mul_mod_n > 0 %} | ||
assert builtin_runners["mul_mod_builtin"].instance_def.batch_size == 1 | ||
{%- endif %} | ||
|
||
ModBuiltinRunner.fill_memory( | ||
memory=memory, | ||
add_mod={{ "(ids.add_mod_ptr.address_, builtin_runners['add_mod_builtin'], " + circuit.add_mod_n|string + ")" if circuit.add_mod_n > 0 else "None" }}, | ||
mul_mod={{ "(ids.mul_mod_ptr.address_, builtin_runners['mul_mod_builtin'], " + circuit.mul_mod_n|string + ")" if circuit.mul_mod_n > 0 else "None" }}, | ||
) | ||
%} | ||
|
||
let range_check96_ptr = range_check96_ptr + {{circuit.total_offset}}; | ||
{% if circuit.add_mod_n > 0 %} | ||
let add_mod_ptr = add_mod_ptr + ModBuiltin.SIZE * {{circuit.add_mod_n}}; | ||
{% endif %} | ||
{%- if circuit.mul_mod_n > 0 -%} | ||
let mul_mod_ptr = mul_mod_ptr + ModBuiltin.SIZE * {{circuit.mul_mod_n}}; | ||
{% endif %} | ||
return {{return_data}}; | ||
|
||
add_offsets: | ||
{%- for offset in circuit.add_mod_offsets_ptr %} | ||
dw {{offset}}; | ||
{%- endfor %} | ||
|
||
mul_offsets: | ||
{%- for offset in circuit.mul_mod_offsets_ptr %} | ||
dw {{offset}}; | ||
{%- endfor %} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from starkware.cairo.common.cairo_builtins import UInt384, ModBuiltin | ||
from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc |
Oops, something went wrong.