Skip to content

Commit

Permalink
Reuse Poseidon2 Parameters in Halo2 Compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
nyunyunyunyu committed Dec 9, 2024
1 parent a794bf6 commit bc38e23
Showing 1 changed file with 52 additions and 49 deletions.
101 changes: 52 additions & 49 deletions extensions/native/compiler/src/constraints/halo2/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
fmt::Debug,
marker::PhantomData,
panic::{catch_unwind, AssertUnwindSafe},
sync::Arc,
sync::{Arc, LazyLock},
};

use ax_stark_backend::p3_field::{ExtensionField, PrimeField};
Expand Down Expand Up @@ -31,6 +31,54 @@ use crate::{
ir::{Config, DslIr, TracedVec, Witness},
};

const POSEIDON2_T: usize = 3;
static POSEIDON2_PARAMS: LazyLock<Poseidon2Params<Fr, POSEIDON2_T>> = LazyLock::new(|| {
use zkhash::{
ark_ff::{BigInteger, PrimeField as _},
fields::bn256::FpBN256 as ark_FpBN256,
poseidon2::poseidon2_instance_bn256::{MAT_DIAG3_M_1, RC3},
};

fn convert_fr(input: ark_FpBN256) -> Fr {
Fr::from_bytes_le(&input.into_bigint().to_bytes_le())
}
const T: usize = 3;
let rounds_f = 8;
let rounds_p = 56;
let mut round_constants: Vec<[Fr; T]> = RC3
.iter()
.map(|vec| {
vec.iter()
.cloned()
.map(convert_fr)
.collect::<Vec<_>>()
.try_into()
.unwrap()
})
.collect();

let rounds_f_beginning = rounds_f / 2;
let p_end = rounds_f_beginning + rounds_p;
let internal_round_constants = round_constants
.drain(rounds_f_beginning..p_end)
.map(|vec| vec[0])
.collect::<Vec<_>>();
let external_round_constants = round_constants;
Poseidon2Params {
rounds_f,
rounds_p,
mat_internal_diag_m_1: MAT_DIAG3_M_1
.iter()
.copied()
.map(convert_fr)
.collect_vec()
.try_into()
.unwrap(),
external_rc: external_round_constants,
internal_rc: internal_round_constants,
}
});

/// The backend for the Halo2 constraint compiler.
#[derive(Debug, Clone)]
pub struct Halo2ConstraintCompiler<C: Config> {
Expand Down Expand Up @@ -278,55 +326,10 @@ impl<C: Config + Debug> Halo2ConstraintCompiler<C> {
}
}
DslIr::CircuitPoseidon2Permute(state_vars) => {
use zkhash::{
ark_ff::{BigInteger, PrimeField as _},
fields::bn256::FpBN256 as ark_FpBN256,
poseidon2::poseidon2_instance_bn256::{MAT_DIAG3_M_1, RC3},
};

fn convert_fr(input: ark_FpBN256) -> Fr {
Fr::from_bytes_le(&input.into_bigint().to_bytes_le())
}
const T: usize = 3;
let rounds_f = 8;
let rounds_p = 56;
let mut round_constants: Vec<[Fr; T]> = RC3
.iter()
.map(|vec| {
vec.iter()
.cloned()
.map(convert_fr)
.collect::<Vec<_>>()
.try_into()
.unwrap()
})
.collect();

let rounds_f_beginning = rounds_f / 2;
let p_end = rounds_f_beginning + rounds_p;
let internal_round_constants = round_constants
.drain(rounds_f_beginning..p_end)
.map(|vec| vec[0])
.collect::<Vec<_>>();
let external_round_constants = round_constants;
let params = Poseidon2Params {
rounds_f,
rounds_p,
mat_internal_diag_m_1: MAT_DIAG3_M_1
.iter()
.copied()
.map(convert_fr)
.collect_vec()
.try_into()
.unwrap(),
external_rc: external_round_constants,
internal_rc: internal_round_constants,
};

let mut state =
Poseidon2State::<Fr, T>::new(state_vars.map(|x| vars[&x.0]));
state.permutation(ctx, gate, &params);
for i in 0..T {
Poseidon2State::<Fr, POSEIDON2_T>::new(state_vars.map(|x| vars[&x.0]));
state.permutation(ctx, gate, &*POSEIDON2_PARAMS);
for i in 0..POSEIDON2_T {
*vars.get_mut(&state_vars[i].0).unwrap() = state.s[i];
}
}
Expand Down

0 comments on commit bc38e23

Please sign in to comment.