diff --git a/extensions/native/compiler/src/constraints/halo2/compiler.rs b/extensions/native/compiler/src/constraints/halo2/compiler.rs index 06a08b757f..b171f72dd2 100644 --- a/extensions/native/compiler/src/constraints/halo2/compiler.rs +++ b/extensions/native/compiler/src/constraints/halo2/compiler.rs @@ -3,7 +3,7 @@ use std::{ fmt::Debug, marker::PhantomData, panic::{catch_unwind, AssertUnwindSafe}, - sync::Arc, + sync::{Arc, LazyLock}, }; use ax_stark_backend::p3_field::{ExtensionField, PrimeField}; @@ -31,6 +31,54 @@ use crate::{ ir::{Config, DslIr, TracedVec, Witness}, }; +const POSEIDON2_T: usize = 3; +static POSEIDON2_PARAMS: LazyLock> = LazyLock::new(|| { + use zkhash::{ + ark_ff::{BigInteger, PrimeField as _}, + fields::bn256::FpBN256 as ark_FpBN256, + poseidon2::poseidon2_instance_bn256::{MAT_DIAG3_M_1, RC3}, + }; + + fn convert_fr(input: ark_FpBN256) -> Fr { + Fr::from_bytes_le(&input.into_bigint().to_bytes_le()) + } + const T: usize = 3; + let rounds_f = 8; + let rounds_p = 56; + let mut round_constants: Vec<[Fr; T]> = RC3 + .iter() + .map(|vec| { + vec.iter() + .cloned() + .map(convert_fr) + .collect::>() + .try_into() + .unwrap() + }) + .collect(); + + let rounds_f_beginning = rounds_f / 2; + let p_end = rounds_f_beginning + rounds_p; + let internal_round_constants = round_constants + .drain(rounds_f_beginning..p_end) + .map(|vec| vec[0]) + .collect::>(); + let external_round_constants = round_constants; + Poseidon2Params { + rounds_f, + rounds_p, + mat_internal_diag_m_1: MAT_DIAG3_M_1 + .iter() + .copied() + .map(convert_fr) + .collect_vec() + .try_into() + .unwrap(), + external_rc: external_round_constants, + internal_rc: internal_round_constants, + } +}); + /// The backend for the Halo2 constraint compiler. #[derive(Debug, Clone)] pub struct Halo2ConstraintCompiler { @@ -278,55 +326,10 @@ impl Halo2ConstraintCompiler { } } DslIr::CircuitPoseidon2Permute(state_vars) => { - use zkhash::{ - ark_ff::{BigInteger, PrimeField as _}, - fields::bn256::FpBN256 as ark_FpBN256, - poseidon2::poseidon2_instance_bn256::{MAT_DIAG3_M_1, RC3}, - }; - - fn convert_fr(input: ark_FpBN256) -> Fr { - Fr::from_bytes_le(&input.into_bigint().to_bytes_le()) - } - const T: usize = 3; - let rounds_f = 8; - let rounds_p = 56; - let mut round_constants: Vec<[Fr; T]> = RC3 - .iter() - .map(|vec| { - vec.iter() - .cloned() - .map(convert_fr) - .collect::>() - .try_into() - .unwrap() - }) - .collect(); - - let rounds_f_beginning = rounds_f / 2; - let p_end = rounds_f_beginning + rounds_p; - let internal_round_constants = round_constants - .drain(rounds_f_beginning..p_end) - .map(|vec| vec[0]) - .collect::>(); - let external_round_constants = round_constants; - let params = Poseidon2Params { - rounds_f, - rounds_p, - mat_internal_diag_m_1: MAT_DIAG3_M_1 - .iter() - .copied() - .map(convert_fr) - .collect_vec() - .try_into() - .unwrap(), - external_rc: external_round_constants, - internal_rc: internal_round_constants, - }; - let mut state = - Poseidon2State::::new(state_vars.map(|x| vars[&x.0])); - state.permutation(ctx, gate, ¶ms); - for i in 0..T { + Poseidon2State::::new(state_vars.map(|x| vars[&x.0])); + state.permutation(ctx, gate, &*POSEIDON2_PARAMS); + for i in 0..POSEIDON2_T { *vars.get_mut(&state_vars[i].0).unwrap() = state.s[i]; } }