Skip to content

Commit

Permalink
chore: clean up for loops and breaks (#1234)
Browse files Browse the repository at this point in the history
* chore: remove unused poseidon2_hash

* chore: remove unused then_may_break and then_or_else_may_break

* chore: remove unused RangeBuilderWithBreaks

* chore: remove static_loop and disable_break options

* chore: remove unused BreakLoop

* chore: remove break support

* chore: update comment

* chore: change all instances of iter to use zip

* fix: lint

* chore: remove unused step_by

* chore: use zip for range

* chore: remove unused RangeBuilder, For

* chore: rename compile_zip to iter_zip

* fix: rename

* chore: rename ZippedPointerIteratorBuilder to IteratorBuilder
  • Loading branch information
yi-sun authored Jan 21, 2025
1 parent a94d0bb commit ab01b0c
Show file tree
Hide file tree
Showing 26 changed files with 583 additions and 1,255 deletions.
3 changes: 2 additions & 1 deletion crates/sdk/src/verifier/common/non_leaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ impl<C: Config> NonLeafVerifierVariables<C> {
let pvs = VmVerifierPvs::<Felt<C::F>>::uninit(builder);
let leaf_verifier_commit = array::from_fn(|_| builder.uninit());

builder.range(0, proofs.len()).for_each(|i, builder| {
builder.range(0, proofs.len()).for_each(|i_vec, builder| {
let i = i_vec[0];
let proof = builder.get(proofs, i);
assert_required_air_for_agg_vm_present(builder, &proof);
let proof_vm_pvs = self.verify_internal_or_leaf_verifier_proof(builder, &proof);
Expand Down
3 changes: 2 additions & 1 deletion crates/sdk/src/verifier/leaf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ impl LeafVmVerifierConfig {

builder.cycle_tracker_start("VerifyProofs");
let pvs = VmVerifierPvs::<Felt<F>>::uninit(&mut builder);
builder.range(0, proofs.len()).for_each(|i, builder| {
builder.range(0, proofs.len()).for_each(|i_vec, builder| {
let i = i_vec[0];
let proof = builder.get(&proofs, i);
assert_required_air_for_app_vm_present(builder, &proof);
StarkVerifier::verify::<DuplexChallengerVariable<C>>(
Expand Down
4 changes: 2 additions & 2 deletions crates/sdk/src/verifier/leaf/vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ impl Hintable<C> for UserPublicValuesRootProof<F> {
fn read(builder: &mut Builder<C>) -> Self::HintVariable {
let len = builder.hint_var();
let sibling_hashes = builder.array(len);
builder.range(0, len).for_each(|i, builder| {
builder.range(0, len).for_each(|i_vec, builder| {
// FIXME: add hint support for slices.
let hash = array::from_fn(|_| builder.hint_felt());
builder.set_value(&sibling_hashes, i, hash);
builder.set_value(&sibling_hashes, i_vec[0], hash);
});
let public_values_commit = array::from_fn(|_| builder.hint_felt());
Self::HintVariable {
Expand Down
13 changes: 1 addition & 12 deletions extensions/native/compiler/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,4 @@ In **static programs**, only constant branches are allowed.
When both `start` and `end` of a loop are constant, the loop is a constant loop. The loop body will be unrolled. This
optimization saves 1 instruction per iteration.

In **static programs**, only constant loops are allowed.

## Break Support
**!!Attention!!**: Break support for constant loops is not perfect. It brings some restrictions which require
developers' awareness.:

- If you want to use `break` in a possibly constant loop, you need to use `.for_each_may_break` instead of `.for_each`.
- If you want to use `break` in a branch inside a loop, you need to use `.then_may_break`/`.then_or_else_may_break`
instead of `.for_each`/`.then_or_else`.
- Inside a **constant loop**, you cannot use a **non-constant branch** to break.


In **static programs**, only constant loops are allowed.
10 changes: 5 additions & 5 deletions extensions/native/compiler/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,24 +169,24 @@ pub fn hintable_derive(input: TokenStream) -> TokenStream {
}
}

struct CompileZipArgs {
struct IterZipArgs {
builder: Expr,
args: Punctuated<Expr, Token![,]>,
}

impl Parse for CompileZipArgs {
impl Parse for IterZipArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let builder = input.parse()?;
let _: Token![,] = input.parse()?;
let args = Punctuated::parse_terminated(input)?;

Ok(CompileZipArgs { builder, args })
Ok(IterZipArgs { builder, args })
}
}

#[proc_macro]
pub fn compile_zip(input: TokenStream) -> TokenStream {
let CompileZipArgs { builder, args } = parse_macro_input!(input as CompileZipArgs);
pub fn iter_zip(input: TokenStream) -> TokenStream {
let IterZipArgs { builder, args } = parse_macro_input!(input as IterZipArgs);
let array_elements = args.iter().map(|arg| {
quote! {
Box::new(#arg.clone()) as Box<dyn ArrayLike<_>>
Expand Down
168 changes: 2 additions & 166 deletions extensions/native/compiler/src/asm/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use alloc::{collections::BTreeMap, vec};
use std::collections::BTreeSet;

use openvm_circuit::arch::instructions::instruction::DebugInfo;
use openvm_stark_backend::p3_field::{ExtensionField, Field, PrimeField32, TwoAdicField};
Expand Down Expand Up @@ -29,10 +28,6 @@ pub(crate) const STACK_TOP: i32 = HEAP_START_ADDRESS - 64;
// #[derive(Debug, Clone, Default)]
pub struct AsmCompiler<F, EF> {
basic_blocks: Vec<BasicBlock<F, EF>>,
break_label: Option<F>,
break_label_map: BTreeMap<F, F>,
break_counter: usize,
contains_break: BTreeSet<F>,
function_labels: BTreeMap<String, F>,
trap_label: F,
word_size: usize,
Expand Down Expand Up @@ -74,25 +69,12 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
pub fn new(word_size: usize) -> Self {
Self {
basic_blocks: vec![BasicBlock::new()],
break_label: None,
break_label_map: BTreeMap::new(),
contains_break: BTreeSet::new(),
function_labels: BTreeMap::new(),
break_counter: 0,
trap_label: F::ONE,
word_size,
}
}

/// Creates a new break label.
pub fn new_break_label(&mut self) -> F {
let label = self.break_counter;
self.break_counter += 1;
let label = F::from_canonical_usize(label);
self.break_label = Some(label);
label
}

/// Builds the operations into assembly instructions.
pub fn build(&mut self, operations: TracedVec<DslIr<AsmConfig<F, EF>>>) {
if self.block_label().is_zero() {
Expand Down Expand Up @@ -365,22 +347,6 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
);
}
}
DslIr::Break => {
let label = self.break_label.expect("No break label set");
let current_block = self.block_label();
self.contains_break.insert(current_block);
self.push(AsmInstruction::Break(label), debug_info);
}
DslIr::For(start, end, step_size, loop_var, block) => {
let for_compiler = ForCompiler {
compiler: self,
start,
end,
step_size,
loop_var,
};
for_compiler.for_each(move |_, builder| builder.build(block), debug_info);
}
DslIr::ZipFor(starts, end0, step_sizes, loop_vars, block) => {
let zip_for_compiler = ZipForCompiler {
compiler: self,
Expand Down Expand Up @@ -815,7 +781,8 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> IfCom
}
}

// Zipped for loop -- loop extends over the first entry in starts and ends
// Zipped for loop -- loop extends over the first entry in starts and end0
// ATTENTION: starting with starts[0] > end0 will lead to undefined behavior.
pub struct ZipForCompiler<'a, F: Field, EF> {
compiler: &'a mut AsmCompiler<F, EF>,
starts: Vec<RVar<F>>,
Expand Down Expand Up @@ -856,8 +823,6 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField>
});

let loop_call_label = self.compiler.block_label();
let break_label = self.compiler.new_break_label();
self.compiler.break_label = Some(break_label);

self.compiler.basic_block();
let loop_label = self.compiler.block_label();
Expand Down Expand Up @@ -898,135 +863,6 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField>
.push_to_block(loop_call_label, instr, debug_info.clone());

self.compiler.basic_block();
let label = self.compiler.block_label();
self.compiler.break_label_map.insert(break_label, label);

for block in self.compiler.contains_break.iter() {
for instruction in self.compiler.basic_blocks[block.as_canonical_u32() as usize]
.0
.iter_mut()
{
if let AsmInstruction::Break(l) = instruction {
if *l == break_label {
*instruction = AsmInstruction::j(label);
}
}
}
}
}
}

/// A builder for a for loop.
///
/// SAFETY: Starting with end < start will lead to undefined behavior.
pub struct ForCompiler<'a, F: Field, EF> {
compiler: &'a mut AsmCompiler<F, EF>,
start: RVar<F>,
end: RVar<F>,
step_size: F,
loop_var: Var<F>,
}

impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> ForCompiler<'_, F, EF> {
pub(super) fn for_each(
mut self,
f: impl FnOnce(Var<F>, &mut AsmCompiler<F, EF>),
debug_info: Option<DebugInfo>,
) {
// The function block structure:
// - Setting the loop range
// - Executing the loop body and incrementing the loop variable
// - the loop condition

// Set the loop variable to the start of the range.
self.set_loop_var(debug_info.clone());

// Save the label of the for loop call.
let loop_call_label = self.compiler.block_label();

// Initialize a break label for this loop.
let break_label = self.compiler.new_break_label();
self.compiler.break_label = Some(break_label);

// A basic block for the loop body
self.compiler.basic_block();

// Save the loop body label for the loop condition.
let loop_label = self.compiler.block_label();

// The loop body.
f(self.loop_var, self.compiler);

// Increment the loop variable.
self.compiler.push(
AsmInstruction::AddFI(self.loop_var.fp(), self.loop_var.fp(), self.step_size),
debug_info.clone(),
);

// Add a basic block for the loop condition.
self.compiler.basic_block();

// Jump to loop body if the loop condition still holds.
self.jump_to_loop_body(loop_label, debug_info.clone());

// Add a jump instruction to the loop condition in the loop call block.
let label = self.compiler.block_label();
let instr = AsmInstruction::j(label);
self.compiler
.push_to_block(loop_call_label, instr, debug_info.clone());

// Initialize the after loop block.
self.compiler.basic_block();

// Resolve the break label.
let label = self.compiler.block_label();
self.compiler.break_label_map.insert(break_label, label);

// Replace the break instruction with a jump to the after loop block.
for block in self.compiler.contains_break.iter() {
for instruction in self.compiler.basic_blocks[block.as_canonical_u32() as usize]
.0
.iter_mut()
{
if let AsmInstruction::Break(l) = instruction {
if *l == break_label {
*instruction = AsmInstruction::j(label);
}
}
}
}

// self.compiler.contains_break.clear();
}

fn set_loop_var(&mut self, debug_info: Option<DebugInfo>) {
match self.start {
RVar::Const(start) => {
self.compiler.push(
AsmInstruction::ImmF(self.loop_var.fp(), start),
debug_info.clone(),
);
}
RVar::Val(var) => {
self.compiler.push(
AsmInstruction::CopyF(self.loop_var.fp(), var.fp()),
debug_info.clone(),
);
}
}
}

fn jump_to_loop_body(&mut self, loop_label: F, debug_info: Option<DebugInfo>) {
match self.end {
RVar::Const(end) => {
let instr = AsmInstruction::BneI(loop_label, self.loop_var.fp(), end);
self.compiler.push(instr, debug_info.clone());
}
RVar::Val(end) => {
let instr = AsmInstruction::Bne(loop_label, self.loop_var.fp(), end.fp());
self.compiler.push(instr, debug_info.clone());
}
}
}
}

Expand Down
4 changes: 0 additions & 4 deletions extensions/native/compiler/src/asm/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,6 @@ pub enum AsmInstruction<F, EF> {
/// Halt.
Halt,

/// Break(label)
Break(F),

/// Perform a Poseidon2 permutation on state starting at address `lhs`
/// and store new state at `rhs`.
/// (a, b) are pointers to (lhs, rhs).
Expand Down Expand Up @@ -159,7 +156,6 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {

pub fn fmt(&self, labels: &BTreeMap<F, String>, f: &mut fmt::Formatter) -> fmt::Result {
match self {
AsmInstruction::Break(_) => panic!("Unresolved break instruction"),
AsmInstruction::LoadFI(dst, src, var_index, size, offset) => {
write!(
f,
Expand Down
1 change: 0 additions & 1 deletion extensions/native/compiler/src/conversion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ fn convert_instruction<F: PrimeField32, EF: ExtensionField<F>>(
options: &CompilerOptions,
) -> Program<F> {
let instructions = match instruction {
AsmInstruction::Break(_) => panic!("Unresolved break instruction"),
AsmInstruction::LoadFI(dst, src, index, size, offset) => vec![
// mem[dst] <- mem[mem[src] + index * size + offset]
inst(
Expand Down
Loading

0 comments on commit ab01b0c

Please sign in to comment.