diff --git a/Cargo.lock b/Cargo.lock index c34be7e54..554da41be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4316,10 +4316,16 @@ dependencies = [ name = "openvm-sha256-air" version = "0.1.0-alpha" dependencies = [ + "hex", + "lazy_static", + "openvm-circuit", "openvm-circuit-primitives", "openvm-stark-backend", + "openvm-stark-sdk", "rand", "sha2", + "test-case", + "test-log", ] [[package]] diff --git a/crates/circuits/sha256-air/Cargo.toml b/crates/circuits/sha256-air/Cargo.toml index bb5cc5ce4..f094b2f7f 100644 --- a/crates/circuits/sha256-air/Cargo.toml +++ b/crates/circuits/sha256-air/Cargo.toml @@ -10,6 +10,13 @@ openvm-stark-backend = { workspace = true } sha2 = { version = "0.10", features = ["compress"] } rand.workspace = true +[dev-dependencies] +openvm-stark-sdk = { workspace = true } +test-case.workspace = true +test-log.workspace = true +lazy_static.workspace = true +openvm-circuit = { workspace = true, features = ["test-utils"] } +hex.workspace = true [features] default = ["parallel"] diff --git a/crates/circuits/sha256-air/src/lib.rs b/crates/circuits/sha256-air/src/lib.rs index b98979bf6..e5616b592 100644 --- a/crates/circuits/sha256-air/src/lib.rs +++ b/crates/circuits/sha256-air/src/lib.rs @@ -9,3 +9,6 @@ mod utils; pub use air::*; pub use columns::*; pub use utils::*; + +#[cfg(test)] +mod tests; diff --git a/crates/circuits/sha256-air/src/tests.rs b/crates/circuits/sha256-air/src/tests.rs index e69de29bb..b28ea7f80 100644 --- a/crates/circuits/sha256-air/src/tests.rs +++ b/crates/circuits/sha256-air/src/tests.rs @@ -0,0 +1,160 @@ +use std::{array, borrow::BorrowMut, cmp::max, sync::Arc}; + +use openvm_circuit::{ + arch::{ + instructions::riscv::RV32_CELL_BITS, testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, + }, + utils::next_power_of_two_or_zero, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, + SubAir, +}; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + interaction::InteractionBuilder, + p3_air::{Air, BaseAir}, + p3_field::{AbstractField, Field, PrimeField32}, + p3_matrix::dense::RowMajorMatrix, + p3_maybe_rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut}, + prover::types::AirProofInput, + rap::{get_air_name, AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, + Chip, ChipUsageGetter, +}; +use openvm_stark_sdk::utils::create_seeded_rng; +use rand::Rng; + +use crate::{ + limbs_into_u32, Sha256Air, Sha256RoundCols, SHA256_BLOCK_U8S, SHA256_DIGEST_WIDTH, SHA256_H, + SHA256_ROUND_WIDTH, SHA256_ROWS_PER_BLOCK, SHA256_WORD_U8S, +}; + +// A wrapper AIR purely for testing purposes +#[derive(Clone, Debug)] +pub struct Sha256TestAir { + pub sub_air: Sha256Air, +} + +impl BaseAirWithPublicValues for Sha256TestAir {} +impl PartitionedBaseAir for Sha256TestAir {} +impl BaseAir for Sha256TestAir { + fn width(&self) -> usize { + >::width(&self.sub_air) + } +} + +impl Air for Sha256TestAir { + fn eval(&self, builder: &mut AB) { + self.sub_air.eval(builder, 0); + } +} + +// A wrapper Chip purely for testing purposes +#[derive(Debug)] +pub struct Sha256TestChip { + pub air: Sha256TestAir, + pub bitwise_lookup_chip: Arc>, + pub records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, +} + +impl Chip for Sha256TestChip +where + Val: PrimeField32, +{ + fn air(&self) -> Arc> { + Arc::new(self.air.clone()) + } + + fn generate_air_proof_input(self) -> AirProofInput { + let non_padded_height = self.current_trace_height(); + let height = next_power_of_two_or_zero(non_padded_height); + let width = self.trace_width(); + let mut values = Val::::zero_vec(height * width); + let mut prev_hash = SHA256_H; + let mut local_block_idx = 0; + let mut global_block_idx = 1; + values + .chunks_exact_mut(width * SHA256_ROWS_PER_BLOCK) + .zip(self.records.iter()) + .for_each(|(block, record)| { + let (input, is_last_block) = record; + let input_words = array::from_fn(|i| { + limbs_into_u32::(array::from_fn(|j| { + input[i * SHA256_WORD_U8S + j] as u32 + })) + }); + self.air.sub_air.generate_block_trace( + block, + width, + 0, + &input_words, + self.bitwise_lookup_chip.as_ref(), + &prev_hash, + *is_last_block, + global_block_idx, + local_block_idx, + &[[Val::::ZERO; 16]; 4], + ); + global_block_idx += 1; + if *is_last_block { + local_block_idx = 0; + prev_hash = SHA256_H; + } else { + local_block_idx += 1; + prev_hash = Sha256Air::get_block_hash(&prev_hash, *input); + } + }); + values[width * non_padded_height..] + .par_chunks_mut(width) + .for_each(|row| { + let cols: &mut Sha256RoundCols> = row.borrow_mut(); + self.air.sub_air.generate_default_row(cols); + }); + values[width..] + .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) + .take(non_padded_height / SHA256_ROWS_PER_BLOCK) + .for_each(|chunk| { + self.air.sub_air.generate_missing_cells(chunk, width, 0); + }); + + AirProofInput::simple(self.air(), RowMajorMatrix::new(values, width), vec![]) + } +} + +impl ChipUsageGetter for Sha256TestChip { + fn air_name(&self) -> String { + get_air_name(&self.air) + } + fn current_trace_height(&self) -> usize { + self.records.len() * SHA256_ROWS_PER_BLOCK + } + + fn trace_width(&self) -> usize { + max(SHA256_ROUND_WIDTH, SHA256_DIGEST_WIDTH) + } +} + +const SELF_BUS_IDX: usize = 28; +#[test] +fn rand_sha256_test() { + let mut rng = create_seeded_rng(); + let tester = VmChipTestBuilder::default(); + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + let len = rng.gen_range(1..100); + let random_records: Vec<_> = (0..len) + .map(|_| (array::from_fn(|_| rng.gen::()), true)) + .collect(); + let chip = Sha256TestChip { + air: Sha256TestAir { + sub_air: Sha256Air::new(bitwise_bus, SELF_BUS_IDX), + }, + bitwise_lookup_chip: bitwise_chip.clone(), + records: random_records, + }; + + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} diff --git a/extensions/sha256/circuit/src/sha256_chip/tests.rs b/extensions/sha256/circuit/src/sha256_chip/tests.rs index 68d16c00b..91974fe35 100644 --- a/extensions/sha256/circuit/src/sha256_chip/tests.rs +++ b/extensions/sha256/circuit/src/sha256_chip/tests.rs @@ -92,7 +92,7 @@ fn rand_sha256_test() { Rv32Sha256Opcode::default_offset(), ); - let num_tests: usize = 1; + let num_tests: usize = 3; for _ in 0..num_tests { set_and_execute(&mut tester, &mut chip, &mut rng, SHA256, None, None); }