Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] generic field expression vm chip #653

Merged
merged 6 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions circuits/ecc/src/field_expression/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub struct ExprBuilder {

// The equations to compute the newly introduced variables. For trace gen only.
pub computes: Vec<SymbolicExpr>,

pub output_indices: Vec<usize>,
}

impl ExprBuilder {
Expand All @@ -74,6 +76,7 @@ impl ExprBuilder {
carry_limbs: vec![],
constraints: vec![],
computes: vec![],
output_indices: vec![],
}
}

Expand Down
6 changes: 6 additions & 0 deletions circuits/ecc/src/field_expression/field_variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ impl FieldVariable {
builder.num_variables - 1
}

pub fn save_output(&mut self) {
let index = self.save();
let mut builder = self.builder.borrow_mut();
builder.output_indices.push(index);
}

pub fn canonical_limb_bits(&self) -> usize {
self.builder.borrow().limb_bits
}
Expand Down
257 changes: 37 additions & 220 deletions vm/src/intrinsics/ecc/sw/add_ne.rs
Original file line number Diff line number Diff line change
@@ -1,227 +1,44 @@
use std::{cell::RefCell, rc::Rc, sync::Arc};
use std::{cell::RefCell, rc::Rc};

use afs_primitives::{
bigint::check_carry_mod_to_zero::CheckCarryModToZeroSubAir,
var_range::{VariableRangeCheckerBus, VariableRangeCheckerChip},
SubAir, TraceSubRowGenerator,
bigint::check_carry_mod_to_zero::CheckCarryModToZeroSubAir, var_range::VariableRangeCheckerBus,
};
use afs_stark_backend::{interaction::InteractionBuilder, rap::BaseAirWithPublicValues};
use ax_ecc_primitives::field_expression::{ExprBuilder, FieldExpr, FieldExprCols};
use ax_ecc_primitives::field_expression::{ExprBuilder, FieldExpr};
use num_bigint_dig::BigUint;
use p3_air::BaseAir;
use p3_field::{AbstractField, Field, PrimeField32};

use super::super::{EcPoint, FIELD_ELEMENT_BITS};
use crate::{
arch::{
instructions::EccOpcode, AdapterAirContext, AdapterRuntimeContext, DynAdapterInterface,
DynArray, MinimalInstruction, Result, VmAdapterInterface, VmCoreAir, VmCoreChip,
},
system::program::Instruction,
utils::{biguint_to_limbs_vec, limbs_to_biguint},
};

#[derive(Clone)]
pub struct SwEcAddNeCoreAir {
pub expr: FieldExpr,
pub offset: usize,
}

impl SwEcAddNeCoreAir {
pub fn new(
modulus: BigUint, // The coordinate field.
num_limbs: usize,
limb_bits: usize,
range_bus: VariableRangeCheckerBus,
offset: usize,
) -> Self {
assert!(modulus.bits() <= num_limbs * limb_bits);
let subair = CheckCarryModToZeroSubAir::new(
modulus.clone(),
limb_bits,
range_bus.index,
range_bus.range_max_bits,
FIELD_ELEMENT_BITS,
);
let builder = ExprBuilder::new(modulus, limb_bits, num_limbs, range_bus.range_max_bits);
let builder = Rc::new(RefCell::new(builder));

let x1 = ExprBuilder::new_input(builder.clone());
let y1 = ExprBuilder::new_input(builder.clone());
let x2 = ExprBuilder::new_input(builder.clone());
let y2 = ExprBuilder::new_input(builder.clone());
let mut lambda = (y2 - y1.clone()) / (x2.clone() - x1.clone());
let mut x3 = lambda.square() - x1.clone() - x2;
x3.save();
let mut y3 = lambda * (x1 - x3.clone()) - y1;
y3.save();

let builder = builder.borrow().clone();
let expr = FieldExpr {
builder,
check_carry_mod_to_zero: subair,
range_bus,
};
Self { expr, offset }
}
}

impl<F: Field> BaseAir<F> for SwEcAddNeCoreAir {
fn width(&self) -> usize {
BaseAir::<F>::width(&self.expr)
}
}

impl<F: Field> BaseAirWithPublicValues<F> for SwEcAddNeCoreAir {}

impl<AB: InteractionBuilder, I> VmCoreAir<AB, I> for SwEcAddNeCoreAir
where
I: VmAdapterInterface<AB::Expr>,
AdapterAirContext<AB::Expr, I>:
From<AdapterAirContext<AB::Expr, DynAdapterInterface<AB::Expr>>>,
{
fn eval(
&self,
builder: &mut AB,
local: &[AB::Var],
_from_pc: AB::Var,
) -> AdapterAirContext<AB::Expr, I> {
assert_eq!(local.len(), BaseAir::<AB::F>::width(&self.expr));
self.expr.eval(builder, local);

let FieldExprCols {
is_valid,
inputs,
vars,
flags,
..
} = self.expr.load_vars(local);
assert_eq!(inputs.len(), 4);
assert_eq!(vars.len(), 3);
assert_eq!(flags.len(), 0);
let reads: Vec<AB::Expr> = inputs.concat().iter().map(|x| (*x).into()).collect();
let writes: Vec<AB::Expr> = vars[1..].concat().iter().map(|x| (*x).into()).collect();

let expected_opcode = EccOpcode::EC_ADD_NE as usize;
let expected_opcode = AB::Expr::from_canonical_usize(expected_opcode);

let instruction = MinimalInstruction {
is_valid: is_valid.into(),
opcode: expected_opcode + AB::Expr::from_canonical_usize(self.offset),
};

let ctx: AdapterAirContext<_, DynAdapterInterface<_>> = AdapterAirContext {
to_pc: None,
reads: reads.into(),
writes: writes.into(),
instruction: instruction.into(),
};
ctx.into()
}
}

pub struct SwEcAddNeCoreChip {
pub air: SwEcAddNeCoreAir,
pub range_checker: Arc<VariableRangeCheckerChip>,
}

impl SwEcAddNeCoreChip {
pub fn new(
modulus: BigUint,
num_limbs: usize,
limb_bits: usize,
range_checker: Arc<VariableRangeCheckerChip>,
offset: usize,
) -> Self {
let air = SwEcAddNeCoreAir::new(modulus, num_limbs, limb_bits, range_checker.bus(), offset);
Self { air, range_checker }
}
}

pub struct SwEcAddNeCoreRecord {
pub p1: EcPoint,
pub p2: EcPoint,
}

impl<F: PrimeField32, I> VmCoreChip<F, I> for SwEcAddNeCoreChip
where
I: VmAdapterInterface<F>,
I::Reads: Into<DynArray<F>>,
AdapterRuntimeContext<F, I>: From<AdapterRuntimeContext<F, DynAdapterInterface<F>>>,
{
type Record = SwEcAddNeCoreRecord;
type Air = SwEcAddNeCoreAir;

fn execute_instruction(
&self,
_instruction: &Instruction<F>,
_from_pc: u32,
reads: I::Reads,
) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
// Input: 2 EcPoint<Fp>, so total 4 field elements.
let field_element_limbs = self.air.expr.canonical_num_limbs();
let limb_bits = self.air.expr.canonical_limb_bits();
let data: DynArray<_> = reads.into();
let data = data.0;
assert_eq!(data.len(), 4 * field_element_limbs);
let data_u32: Vec<u32> = data.iter().map(|x| x.as_canonical_u32()).collect();

let x1 = limbs_to_biguint(&data_u32[..field_element_limbs], limb_bits);
let y1 = limbs_to_biguint(
&data_u32[field_element_limbs..2 * field_element_limbs],
limb_bits,
);
let x2 = limbs_to_biguint(
&data_u32[2 * field_element_limbs..3 * field_element_limbs],
limb_bits,
);
let y2 = limbs_to_biguint(
&data_u32[3 * field_element_limbs..4 * field_element_limbs],
limb_bits,
);

let vars = self
.air
.expr
.execute(vec![x1.clone(), y1.clone(), x2.clone(), y2.clone()], vec![]);
assert_eq!(vars.len(), 3); // lambda, x3, y3
let x3 = vars[1].clone();
let y3 = vars[2].clone();

let x3_limbs = biguint_to_limbs_vec(x3, limb_bits, field_element_limbs);
let y3_limbs = biguint_to_limbs_vec(y3, limb_bits, field_element_limbs);
let writes = [x3_limbs, y3_limbs]
.concat()
.into_iter()
.map(|x| F::from_canonical_u32(x))
.collect::<Vec<_>>();
let ctx = AdapterRuntimeContext::<_, DynAdapterInterface<_>>::without_pc(writes);

Ok((
ctx.into(),
SwEcAddNeCoreRecord {
p1: EcPoint { x: x1, y: y1 },
p2: EcPoint { x: x2, y: y2 },
},
))
}

fn get_opcode_name(&self, _opcode: usize) -> String {
"SwEcAddNe".to_string()
}

fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
self.air.expr.generate_subrow(
(
&self.range_checker,
vec![record.p1.x, record.p1.y, record.p2.x, record.p2.y],
vec![],
),
row_slice,
);
}

fn air(&self) -> &Self::Air {
&self.air
use super::super::FIELD_ELEMENT_BITS;

pub fn ec_add_ne_expr(
modulus: BigUint, // The coordinate field.
num_limbs: usize,
limb_bits: usize,
range_bus: VariableRangeCheckerBus,
) -> FieldExpr {
assert!(modulus.bits() <= num_limbs * limb_bits);
let subair = CheckCarryModToZeroSubAir::new(
modulus.clone(),
limb_bits,
range_bus.index,
range_bus.range_max_bits,
FIELD_ELEMENT_BITS,
);
let builder = ExprBuilder::new(modulus, limb_bits, num_limbs, range_bus.range_max_bits);
let builder = Rc::new(RefCell::new(builder));

let x1 = ExprBuilder::new_input(builder.clone());
let y1 = ExprBuilder::new_input(builder.clone());
let x2 = ExprBuilder::new_input(builder.clone());
let y2 = ExprBuilder::new_input(builder.clone());
let mut lambda = (y2 - y1.clone()) / (x2.clone() - x1.clone());
let mut x3 = lambda.square() - x1.clone() - x2;
x3.save_output();
let mut y3 = lambda * (x1 - x3.clone()) - y1;
y3.save_output();

let builder = builder.borrow().clone();
FieldExpr {
builder,
check_carry_mod_to_zero: subair,
range_bus,
}
}
Loading