diff --git a/acvm/Cargo.toml b/acvm/Cargo.toml index 751e5e1f0..518b8b6e3 100644 --- a/acvm/Cargo.toml +++ b/acvm/Cargo.toml @@ -28,6 +28,9 @@ async-trait = "0.1" default = ["bn254"] bn254 = ["acir/bn254", "stdlib/bn254", "brillig_vm/bn254", "blackbox_solver/bn254"] bls12_381 = ["acir/bls12_381", "stdlib/bls12_381", "brillig_vm/bls12_381", "blackbox_solver/bls12_381"] +testing = ["stdlib/testing", "unstable-fallbacks"] +unstable-fallbacks = [] [dev-dependencies] rand = "0.8.5" +proptest = "1.2.0" diff --git a/acvm/src/compiler/transformers/fallback.rs b/acvm/src/compiler/transformers/fallback.rs index 27047e2c1..095397a9a 100644 --- a/acvm/src/compiler/transformers/fallback.rs +++ b/acvm/src/compiler/transformers/fallback.rs @@ -75,7 +75,7 @@ impl FallbackTransformer { lhs.num_bits, rhs.num_bits, "number of bits specified for each input must be the same" ); - stdlib::fallback::and( + stdlib::blackbox_fallbacks::and( Expression::from(lhs.witness), Expression::from(rhs.witness), *output, @@ -88,7 +88,7 @@ impl FallbackTransformer { lhs.num_bits, rhs.num_bits, "number of bits specified for each input must be the same" ); - stdlib::fallback::xor( + stdlib::blackbox_fallbacks::xor( Expression::from(lhs.witness), Expression::from(rhs.witness), *output, @@ -98,12 +98,26 @@ impl FallbackTransformer { } BlackBoxFuncCall::RANGE { input } => { // Note there are no outputs because range produces no outputs - stdlib::fallback::range( + stdlib::blackbox_fallbacks::range( Expression::from(input.witness), input.num_bits, current_witness_idx, ) } + #[cfg(feature = "unstable-fallbacks")] + BlackBoxFuncCall::SHA256 { inputs, outputs } => { + let mut sha256_inputs = Vec::new(); + for input in inputs.iter() { + let witness_index = Expression::from(input.witness); + let num_bits = input.num_bits; + sha256_inputs.push((witness_index, num_bits)); + } + stdlib::blackbox_fallbacks::sha256( + sha256_inputs, + outputs.to_vec(), + current_witness_idx, + ) + } _ => { return Err(CompileError::UnsupportedBlackBox(gc.get_black_box_func())); } diff --git a/acvm/src/pwg/directives/sorting.rs b/acvm/src/pwg/directives/sorting.rs index baeb8677c..58d19ac48 100644 --- a/acvm/src/pwg/directives/sorting.rs +++ b/acvm/src/pwg/directives/sorting.rs @@ -247,6 +247,7 @@ pub(super) fn route(inputs: Vec, outputs: Vec) -> Ve mod tests { use super::route; use acir::FieldElement; + use proptest as _; use rand::prelude::*; fn execute_network(config: Vec, inputs: Vec) -> Vec { diff --git a/acvm/tests/solver.rs b/acvm/tests/solver.rs index 159d82826..59c3df5a7 100644 --- a/acvm/tests/solver.rs +++ b/acvm/tests/solver.rs @@ -17,7 +17,7 @@ use acvm::{ }; use blackbox_solver::BlackBoxResolutionError; -struct StubbedBackend; +pub(crate) struct StubbedBackend; impl BlackBoxFunctionSolver for StubbedBackend { fn schnorr_verify( diff --git a/acvm/tests/stdlib.rs b/acvm/tests/stdlib.rs new file mode 100644 index 000000000..a52c6f5fd --- /dev/null +++ b/acvm/tests/stdlib.rs @@ -0,0 +1,202 @@ +#![cfg(feature = "testing")] +mod solver; +use crate::solver::StubbedBackend; +use acir::{ + circuit::{ + opcodes::{BlackBoxFuncCall, FunctionInput}, + Circuit, Opcode, PublicInputs, + }, + native_types::Witness, + FieldElement, +}; +use acvm::{ + compiler::{compile, CircuitSimplifier}, + pwg::{ACVMStatus, ACVM}, + Language, +}; +use proptest::prelude::*; +use sha2::{Digest, Sha256}; +use std::collections::{BTreeMap, BTreeSet}; +use stdlib::blackbox_fallbacks::UInt32; + +proptest! { + #[test] + fn test_uint32_ror(x in 0..u32::MAX, y in 0..32_u32) { + let fe = FieldElement::from(x as u128); + let w = Witness(1); + let result = x.rotate_right(y); + let sha256_u32 = UInt32::new(w); + let (w, extra_gates, _) = sha256_u32.ror(y, 2); + let witness_assignments = BTreeMap::from([(Witness(1), fe)]).into(); + let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments); + let solver_status = acvm.solve(); + + prop_assert_eq!(acvm.witness_map().get(&w.get_inner()).unwrap(), &FieldElement::from(result as u128)); + prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved"); + } + + #[test] + fn test_uint32_euclidean_division(x in 0..u32::MAX, y in 0..u32::MAX) { + let lhs = FieldElement::from(x as u128); + let rhs = FieldElement::from(y as u128); + let w1 = Witness(1); + let w2 = Witness(2); + let q = x.div_euclid(y); + let r = x.rem_euclid(y); + let u32_1 = UInt32::new(w1); + let u32_2 = UInt32::new(w2); + let (q_w, r_w, extra_gates, _) = UInt32::euclidean_division(&u32_1, &u32_2, 3); + let witness_assignments = BTreeMap::from([(Witness(1), lhs),(Witness(2), rhs)]).into(); + let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments); + let solver_status = acvm.solve(); + + prop_assert_eq!(acvm.witness_map().get(&q_w.get_inner()).unwrap(), &FieldElement::from(q as u128)); + prop_assert_eq!(acvm.witness_map().get(&r_w.get_inner()).unwrap(), &FieldElement::from(r as u128)); + prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved"); + } + + #[test] + fn test_uint32_add(x in 0..u32::MAX, y in 0..u32::MAX, z in 0..u32::MAX) { + let lhs = FieldElement::from(x as u128); + let rhs = FieldElement::from(y as u128); + let rhs_z = FieldElement::from(z as u128); + let result = FieldElement::from(((x as u128).wrapping_add(y as u128) % (1_u128 << 32)).wrapping_add(z as u128) % (1_u128 << 32)); + let w1 = Witness(1); + let w2 = Witness(2); + let w3 = Witness(3); + let u32_1 = UInt32::new(w1); + let u32_2 = UInt32::new(w2); + let u32_3 = UInt32::new(w3); + let mut gates = Vec::new(); + let (w, extra_gates, num_witness) = u32_1.add(&u32_2, 4); + gates.extend(extra_gates); + let (w2, extra_gates, _) = w.add(&u32_3, num_witness); + gates.extend(extra_gates); + let witness_assignments = BTreeMap::from([(Witness(1), lhs), (Witness(2), rhs), (Witness(3), rhs_z)]).into(); + let mut acvm = ACVM::new(StubbedBackend, gates, witness_assignments); + let solver_status = acvm.solve(); + + prop_assert_eq!(acvm.witness_map().get(&w2.get_inner()).unwrap(), &result); + prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved"); + } + + #[test] + fn test_uint32_sub(x in 0..u32::MAX, y in 0..u32::MAX, z in 0..u32::MAX) { + let lhs = FieldElement::from(x as u128); + let rhs = FieldElement::from(y as u128); + let rhs_z = FieldElement::from(z as u128); + let result = FieldElement::from(((x as u128).wrapping_sub(y as u128) % (1_u128 << 32)).wrapping_sub(z as u128) % (1_u128 << 32)); + let w1 = Witness(1); + let w2 = Witness(2); + let w3 = Witness(3); + let u32_1 = UInt32::new(w1); + let u32_2 = UInt32::new(w2); + let u32_3 = UInt32::new(w3); + let mut gates = Vec::new(); + let (w, extra_gates, num_witness) = u32_1.sub(&u32_2, 4); + gates.extend(extra_gates); + let (w2, extra_gates, _) = w.sub(&u32_3, num_witness); + gates.extend(extra_gates); + let witness_assignments = BTreeMap::from([(Witness(1), lhs), (Witness(2), rhs), (Witness(3), rhs_z)]).into(); + let mut acvm = ACVM::new(StubbedBackend, gates, witness_assignments); + let solver_status = acvm.solve(); + + prop_assert_eq!(acvm.witness_map().get(&w2.get_inner()).unwrap(), &result); + prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved"); + } + + #[test] + fn test_uint32_left_shift(x in 0..u32::MAX, y in 0..32_u32) { + let lhs = FieldElement::from(x as u128); + let w1 = Witness(1); + let result = x.overflowing_shl(y).0; + let u32_1 = UInt32::new(w1); + let (w, extra_gates, _) = u32_1.leftshift(y, 2); + let witness_assignments = BTreeMap::from([(Witness(1), lhs)]).into(); + let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments); + let solver_status = acvm.solve(); + + prop_assert_eq!(acvm.witness_map().get(&w.get_inner()).unwrap(), &FieldElement::from(result as u128)); + prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved"); + } + + #[test] + fn test_uint32_right_shift(x in 0..u32::MAX, y in 0..32_u32) { + let lhs = FieldElement::from(x as u128); + let w1 = Witness(1); + let result = x.overflowing_shr(y).0; + let u32_1 = UInt32::new(w1); + let (w, extra_gates, _) = u32_1.rightshift(y, 2); + let witness_assignments = BTreeMap::from([(Witness(1), lhs)]).into(); + let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments); + let solver_status = acvm.solve(); + + prop_assert_eq!(acvm.witness_map().get(&w.get_inner()).unwrap(), &FieldElement::from(result as u128)); + prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved"); + } +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(3))] + #[test] + fn test_sha256(input_values in proptest::collection::vec(0..u8::MAX, 1..50)) { + let mut opcodes = Vec::new(); + let mut witness_assignments = BTreeMap::new(); + let mut sha256_input_witnesses: Vec = Vec::new(); + let mut correct_result_witnesses: Vec = Vec::new(); + let mut output_witnesses: Vec = Vec::new(); + + // prepare test data + hash_witnesses!(input_values, witness_assignments, sha256_input_witnesses, correct_result_witnesses, output_witnesses, Sha256); + let sha256_blackbox = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::SHA256 { inputs: sha256_input_witnesses, outputs: output_witnesses }); + opcodes.push(sha256_blackbox); + + // compile circuit + let circuit_simplifier = CircuitSimplifier::new(witness_assignments.len() as u32 + 32); + let circuit = Circuit {current_witness_index: witness_assignments.len() as u32 + 32, + opcodes, public_parameters: PublicInputs(BTreeSet::new()), return_values: PublicInputs(BTreeSet::new()) }; + let circuit = compile(circuit, Language::PLONKCSat{ width: 3 }, does_not_support_sha256, &circuit_simplifier).unwrap().0; + + // solve witnesses + let mut acvm = ACVM::new(StubbedBackend, circuit.opcodes, witness_assignments.into()); + let solver_status = acvm.solve(); + + prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved"); + } +} + +fn does_not_support_sha256(opcode: &Opcode) -> bool { + !matches!(opcode, Opcode::BlackBoxFuncCall(BlackBoxFuncCall::SHA256 { .. })) +} + +#[macro_export] +macro_rules! hash_witnesses { + ( + $input_values:ident, + $witness_assignments:ident, + $input_witnesses: ident, + $correct_result_witnesses:ident, + $output_witnesses:ident, + $hasher:ident + ) => { + let mut counter = 0; + let output = $hasher::digest($input_values.clone()); + for inp_v in $input_values { + counter += 1; + let function_input = FunctionInput { witness: Witness(counter), num_bits: 8 }; + $input_witnesses.push(function_input); + $witness_assignments.insert(Witness(counter), FieldElement::from(inp_v as u128)); + } + + for o_v in output { + counter += 1; + $correct_result_witnesses.push(Witness(counter)); + $witness_assignments.insert(Witness(counter), FieldElement::from(o_v as u128)); + } + + for _ in 0..32 { + counter += 1; + $output_witnesses.push(Witness(counter)); + } + }; +} diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index b5bb5d64d..fd05c02ce 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -17,3 +17,4 @@ acir.workspace = true default = ["bn254"] bn254 = ["acir/bn254"] bls12_381 = ["acir/bls12_381"] +testing = ["bn254"] diff --git a/stdlib/src/fallback.rs b/stdlib/src/blackbox_fallbacks/logic_fallbacks.rs similarity index 53% rename from stdlib/src/fallback.rs rename to stdlib/src/blackbox_fallbacks/logic_fallbacks.rs index 3eaab2932..7ae0d59e9 100644 --- a/stdlib/src/fallback.rs +++ b/stdlib/src/blackbox_fallbacks/logic_fallbacks.rs @@ -1,76 +1,10 @@ -use crate::helpers::VariableStore; +use super::utils::bit_decomposition; use acir::{ acir_field::FieldElement, - circuit::{directives::Directive, Opcode}, + circuit::Opcode, native_types::{Expression, Witness}, }; -// Perform bit decomposition on the provided expression -#[deprecated(note = "use bit_decomposition function instead")] -pub fn split( - gate: Expression, - bit_size: u32, - num_witness: u32, - new_gates: &mut Vec, -) -> Vec { - let (extra_gates, bits, _) = bit_decomposition(gate, bit_size, num_witness); - new_gates.extend(extra_gates); - bits -} - -// Generates opcodes and directives to bit decompose the input `gate` -// Returns the bits and the updated witness counter -// TODO:Ideally, we return the updated witness counter, or we require the input -// TODO to be a VariableStore. We are not doing this because we want migration to -// TODO be less painful -pub(crate) fn bit_decomposition( - gate: Expression, - bit_size: u32, - mut num_witness: u32, -) -> (Vec, Vec, u32) { - let mut new_gates = Vec::new(); - let mut variables = VariableStore::new(&mut num_witness); - - // First create a witness for each bit - let mut bit_vector = Vec::with_capacity(bit_size as usize); - for _ in 0..bit_size { - bit_vector.push(variables.new_variable()) - } - - // Next create a directive which computes those bits. - new_gates.push(Opcode::Directive(Directive::ToLeRadix { - a: gate.clone(), - b: bit_vector.clone(), - radix: 2, - })); - - // Now apply constraints to the bits such that they are the bit decomposition - // of the input and each bit is actually a bit - let mut binary_exprs = Vec::new(); - let mut bit_decomp_constraint = gate; - let mut two_pow: FieldElement = FieldElement::one(); - let two = FieldElement::from(2_i128); - for &bit in &bit_vector { - // Bit constraint to ensure each bit is a zero or one; bit^2 - bit = 0 - let mut expr = Expression::default(); - expr.push_multiplication_term(FieldElement::one(), bit, bit); - expr.push_addition_term(-FieldElement::one(), bit); - binary_exprs.push(Opcode::Arithmetic(expr)); - - // Constraint to ensure that the bits are constrained to be a bit decomposition - // of the input - // ie \sum 2^i * x_i = input - bit_decomp_constraint.push_addition_term(-two_pow, bit); - two_pow = two * two_pow; - } - - new_gates.extend(binary_exprs); - bit_decomp_constraint.sort(); // TODO: we have an issue open to check if this is needed. Ideally, we remove it. - new_gates.push(Opcode::Arithmetic(bit_decomp_constraint)); - - (new_gates, bit_vector, variables.finalize()) -} - // Range constraint pub fn range(gate: Expression, bit_size: u32, num_witness: u32) -> (u32, Vec) { let (new_gates, _, updated_witness_counter) = bit_decomposition(gate, bit_size, num_witness); diff --git a/stdlib/src/blackbox_fallbacks/mod.rs b/stdlib/src/blackbox_fallbacks/mod.rs new file mode 100644 index 000000000..263e60351 --- /dev/null +++ b/stdlib/src/blackbox_fallbacks/mod.rs @@ -0,0 +1,7 @@ +mod logic_fallbacks; +mod sha256; +mod uint32; +mod utils; +pub use logic_fallbacks::{and, range, xor}; +pub use sha256::sha256; +pub use uint32::UInt32; diff --git a/stdlib/src/blackbox_fallbacks/sha256.rs b/stdlib/src/blackbox_fallbacks/sha256.rs new file mode 100644 index 000000000..9a864142c --- /dev/null +++ b/stdlib/src/blackbox_fallbacks/sha256.rs @@ -0,0 +1,381 @@ +//! Sha256 fallback function. +use super::uint32::UInt32; +use super::utils::{byte_decomposition, round_to_nearest_byte}; +use crate::helpers::VariableStore; +use acir::{ + brillig::{self, BinaryFieldOp, RegisterIndex}, + circuit::{ + brillig::{Brillig, BrilligInputs, BrilligOutputs}, + opcodes::{BlackBoxFuncCall, FunctionInput}, + Opcode, + }, + native_types::{Expression, Witness}, + FieldElement, +}; + +const INIT_CONSTANTS: [u128; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +const ROUND_CONSTANTS: [u128; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; + +pub fn sha256( + inputs: Vec<(Expression, u32)>, + outputs: Vec, + mut num_witness: u32, +) -> (u32, Vec) { + let mut new_gates = Vec::new(); + let mut new_inputs = Vec::new(); + let mut total_num_bytes = 0; + + // Decompose the input field elements into bytes and collect the resulting witnesses. + for (witness, num_bits) in inputs { + let num_bytes = round_to_nearest_byte(num_bits); + total_num_bytes += num_bytes; + let (extra_gates, inputs, updated_witness_counter) = + byte_decomposition(witness, num_bytes, num_witness); + new_gates.extend(extra_gates); + new_inputs.extend(inputs); + num_witness = updated_witness_counter; + } + + let (result, num_witness, extra_gates) = + create_sha256_constraint(new_inputs, total_num_bytes, num_witness); + new_gates.extend(extra_gates); + + // constrain the outputs to be the same as the result of the circuit + for i in 0..outputs.len() { + let mut expr = Expression::from(outputs[i]); + expr.push_addition_term(-FieldElement::one(), result[i]); + new_gates.push(Opcode::Arithmetic(expr)); + } + (num_witness, new_gates) +} + +fn create_sha256_constraint( + mut input: Vec, + total_num_bytes: u32, + num_witness: u32, +) -> (Vec, u32, Vec) { + let mut new_gates = Vec::new(); + + // pad the bytes according to sha256 padding rules + let message_bits = total_num_bytes * 8; + let (mut num_witness, pad_witness, extra_gates) = pad(128, 8, num_witness); + new_gates.extend(extra_gates); + input.push(pad_witness); + let bytes_per_block = 64; + let num_bytes = (input.len() + 8) as u32; + let num_blocks = num_bytes / bytes_per_block + ((num_bytes % bytes_per_block != 0) as u32); + let num_total_bytes = num_blocks * bytes_per_block; + for _ in num_bytes..num_total_bytes { + let (updated_witness_counter, pad_witness, extra_gates) = pad(0, 8, num_witness); + num_witness = updated_witness_counter; + new_gates.extend(extra_gates); + input.push(pad_witness); + } + let (num_witness, pad_witness, extra_gates) = pad(message_bits, 64, num_witness); + new_gates.extend(extra_gates); + let (extra_gates, pad_witness, num_witness) = + byte_decomposition(pad_witness.into(), 8, num_witness); + new_gates.extend(extra_gates); + input.extend(pad_witness); + + // turn witness into u32 and load sha256 state + let (input, extra_gates, num_witness) = UInt32::from_witnesses(&input, num_witness); + new_gates.extend(extra_gates); + let (mut rolling_hash, extra_gates, num_witness) = prepare_state_constants(num_witness); + new_gates.extend(extra_gates); + let (round_constants, extra_gates, mut num_witness) = prepare_round_constants(num_witness); + new_gates.extend(extra_gates); + // split the input into blocks of size 16 + let input: Vec> = input.chunks(16).map(|block| block.to_vec()).collect(); + + // process sha256 blocks + for i in &input { + let (new_rolling_hash, extra_gates, updated_witness_counter) = + sha256_block(i, rolling_hash.clone(), round_constants.clone(), num_witness); + new_gates.extend(extra_gates); + num_witness = updated_witness_counter; + rolling_hash = new_rolling_hash; + } + + // decompose the result bytes in u32 to u8 + let (extra_gates, byte1, num_witness) = + byte_decomposition(Expression::from(rolling_hash[0].inner), 4, num_witness); + new_gates.extend(extra_gates); + let (extra_gates, byte2, num_witness) = + byte_decomposition(Expression::from(rolling_hash[1].inner), 4, num_witness); + new_gates.extend(extra_gates); + let (extra_gates, byte3, num_witness) = + byte_decomposition(Expression::from(rolling_hash[2].inner), 4, num_witness); + new_gates.extend(extra_gates); + let (extra_gates, byte4, num_witness) = + byte_decomposition(Expression::from(rolling_hash[3].inner), 4, num_witness); + new_gates.extend(extra_gates); + let (extra_gates, byte5, num_witness) = + byte_decomposition(Expression::from(rolling_hash[4].inner), 4, num_witness); + new_gates.extend(extra_gates); + let (extra_gates, byte6, num_witness) = + byte_decomposition(Expression::from(rolling_hash[5].inner), 4, num_witness); + new_gates.extend(extra_gates); + let (extra_gates, byte7, num_witness) = + byte_decomposition(Expression::from(rolling_hash[6].inner), 4, num_witness); + new_gates.extend(extra_gates); + let (extra_gates, byte8, num_witness) = + byte_decomposition(Expression::from(rolling_hash[7].inner), 4, num_witness); + new_gates.extend(extra_gates); + + let result = vec![byte1, byte2, byte3, byte4, byte5, byte6, byte7, byte8] + .into_iter() + .flatten() + .collect(); + + (result, num_witness, new_gates) +} + +fn pad(number: u32, bit_size: u32, mut num_witness: u32) -> (u32, Witness, Vec) { + let mut new_gates = Vec::new(); + let mut variables = VariableStore::new(&mut num_witness); + let pad = variables.new_variable(); + + let brillig_opcode = Opcode::Brillig(Brillig { + inputs: vec![ + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![], + q_c: FieldElement::from(number as u128), + }), + BrilligInputs::Single(Expression::default()), + ], + outputs: vec![BrilligOutputs::Simple(pad)], + foreign_call_results: vec![], + bytecode: vec![brillig::Opcode::BinaryFieldOp { + op: BinaryFieldOp::Add, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(1), + destination: RegisterIndex::from(0), + }], + predicate: None, + }); + new_gates.push(brillig_opcode); + + let range = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { + input: FunctionInput { witness: pad, num_bits: bit_size }, + }); + new_gates.push(range); + + (num_witness, pad, new_gates) +} + +fn sha256_block( + input: &[UInt32], + rolling_hash: Vec, + round_constants: Vec, + mut num_witness: u32, +) -> (Vec, Vec, u32) { + let mut new_gates = Vec::new(); + let mut w = Vec::new(); + w.extend(input.to_owned()); + + for i in 16..64 { + // calculate s0 `w[i - 15].ror(7) ^ w[i - 15].ror(18) ^ (w[i - 15] >> 3)` + let (a1, extra_gates, updated_witness_counter) = w[i - 15].ror(7, num_witness); + new_gates.extend(extra_gates); + let (a2, extra_gates, updated_witness_counter) = w[i - 15].ror(18, updated_witness_counter); + new_gates.extend(extra_gates); + let (a3, extra_gates, updated_witness_counter) = + w[i - 15].rightshift(3, updated_witness_counter); + new_gates.extend(extra_gates); + let (a4, extra_gates, updated_witness_counter) = a1.xor(a2, updated_witness_counter); + new_gates.extend(extra_gates); + let (s0, extra_gates, updated_witness_counter) = a4.xor(a3, updated_witness_counter); + new_gates.extend(extra_gates); + + // calculate s1 `w[i - 2].ror(17) ^ w[i - 2].ror(19) ^ (w[i - 2] >> 10)` + let (b1, extra_gates, updated_witness_counter) = w[i - 2].ror(17, updated_witness_counter); + new_gates.extend(extra_gates); + let (b2, extra_gates, updated_witness_counter) = w[i - 2].ror(19, updated_witness_counter); + new_gates.extend(extra_gates); + let (b3, extra_gates, updated_witness_counter) = + w[i - 2].rightshift(10, updated_witness_counter); + new_gates.extend(extra_gates); + let (b4, extra_gates, updated_witness_counter) = b1.xor(b2, updated_witness_counter); + new_gates.extend(extra_gates); + let (s1, extra_gates, updated_witness_counter) = b4.xor(b3, updated_witness_counter); + new_gates.extend(extra_gates); + + // calculate w[i] `w[i - 16] + w[i - 7] + s0 + s1` + let (c1, extra_gates, updated_witness_counter) = + w[i - 16].add(&w[i - 7], updated_witness_counter); + new_gates.extend(extra_gates); + let (c2, extra_gates, updated_witness_counter) = c1.add(&s0, updated_witness_counter); + new_gates.extend(extra_gates); + let (c3, extra_gates, updated_witness_counter) = c2.add(&s1, updated_witness_counter); + new_gates.extend(extra_gates); + w.push(c3); + num_witness = updated_witness_counter; + } + + let mut a = rolling_hash[0]; + let mut b = rolling_hash[1]; + let mut c = rolling_hash[2]; + let mut d = rolling_hash[3]; + let mut e = rolling_hash[4]; + let mut f = rolling_hash[5]; + let mut g = rolling_hash[6]; + let mut h = rolling_hash[7]; + + #[allow(non_snake_case)] + for i in 0..64 { + // calculate S1 `e.ror(6) ^ e.ror(11) ^ e.ror(25)` + let (a1, extra_gates, updated_witness_counter) = e.ror(6, num_witness); + new_gates.extend(extra_gates); + let (a2, extra_gates, updated_witness_counter) = e.ror(11, updated_witness_counter); + new_gates.extend(extra_gates); + let (a3, extra_gates, updated_witness_counter) = e.ror(25, updated_witness_counter); + new_gates.extend(extra_gates); + let (a4, extra_gates, updated_witness_counter) = a1.xor(a2, updated_witness_counter); + new_gates.extend(extra_gates); + let (S1, extra_gates, updated_witness_counter) = a4.xor(a3, updated_witness_counter); + new_gates.extend(extra_gates); + + // calculate ch `(e & f) + (~e & g)` + let (b1, extra_gates, updated_witness_counter) = e.and(&f, updated_witness_counter); + new_gates.extend(extra_gates); + let (b2, extra_gates, updated_witness_counter) = e.not(updated_witness_counter); + new_gates.extend(extra_gates); + let (b3, extra_gates, updated_witness_counter) = b2.and(&g, updated_witness_counter); + new_gates.extend(extra_gates); + let (ch, extra_gates, updated_witness_counter) = b1.add(&b3, updated_witness_counter); + new_gates.extend(extra_gates); + + // caculate temp1 `h + S1 + ch + round_constants[i] + w[i]` + let (c1, extra_gates, updated_witness_counter) = h.add(&S1, updated_witness_counter); + new_gates.extend(extra_gates); + let (c2, extra_gates, updated_witness_counter) = c1.add(&ch, updated_witness_counter); + new_gates.extend(extra_gates); + let (c3, extra_gates, updated_witness_counter) = + c2.add(&round_constants[i], updated_witness_counter); + new_gates.extend(extra_gates); + let (temp1, extra_gates, updated_witness_counter) = c3.add(&w[i], updated_witness_counter); + new_gates.extend(extra_gates); + + // calculate S0 `a.ror(2) ^ a.ror(13) ^ a.ror(22)` + let (d1, extra_gates, updated_witness_counter) = a.ror(2, updated_witness_counter); + new_gates.extend(extra_gates); + let (d2, extra_gates, updated_witness_counter) = a.ror(13, updated_witness_counter); + new_gates.extend(extra_gates); + let (d3, extra_gates, updated_witness_counter) = a.ror(22, updated_witness_counter); + new_gates.extend(extra_gates); + let (d4, extra_gates, updated_witness_counter) = d1.xor(d2, updated_witness_counter); + new_gates.extend(extra_gates); + let (S0, extra_gates, updated_witness_counter) = d4.xor(d3, updated_witness_counter); + new_gates.extend(extra_gates); + + // calculate T0 `b & c` + let (T0, extra_gates, updated_witness_counter) = b.and(&c, updated_witness_counter); + new_gates.extend(extra_gates); + + // calculate maj `(a & (b + c - (T0 + T0))) + T0` which is the same as `(a & b) ^ (a & c) ^ (b & c)` + let (e1, extra_gates, updated_witness_counter) = T0.add(&T0, updated_witness_counter); + new_gates.extend(extra_gates); + let (e2, extra_gates, updated_witness_counter) = c.sub(&e1, updated_witness_counter); + new_gates.extend(extra_gates); + let (e3, extra_gates, updated_witness_counter) = b.add(&e2, updated_witness_counter); + new_gates.extend(extra_gates); + let (e4, extra_gates, updated_witness_counter) = a.and(&e3, updated_witness_counter); + new_gates.extend(extra_gates); + let (maj, extra_gates, updated_witness_counter) = e4.add(&T0, updated_witness_counter); + new_gates.extend(extra_gates); + + // calculate temp2 `S0 + maj` + let (temp2, extra_gates, updated_witness_counter) = S0.add(&maj, updated_witness_counter); + new_gates.extend(extra_gates); + + h = g; + g = f; + f = e; + let (new_e, extra_gates, updated_witness_counter) = d.add(&temp1, updated_witness_counter); + new_gates.extend(extra_gates); + d = c; + c = b; + b = a; + let (new_a, extra_gates, updated_witness_counter) = + temp1.add(&temp2, updated_witness_counter); + new_gates.extend(extra_gates); + num_witness = updated_witness_counter; + a = new_a; + e = new_e; + } + + let mut output = Vec::new(); + let (output0, extra_gates, num_witness) = a.add(&rolling_hash[0], num_witness); + new_gates.extend(extra_gates); + let (output1, extra_gates, num_witness) = b.add(&rolling_hash[1], num_witness); + new_gates.extend(extra_gates); + let (output2, extra_gates, num_witness) = c.add(&rolling_hash[2], num_witness); + new_gates.extend(extra_gates); + let (output3, extra_gates, num_witness) = d.add(&rolling_hash[3], num_witness); + new_gates.extend(extra_gates); + let (output4, extra_gates, num_witness) = e.add(&rolling_hash[4], num_witness); + new_gates.extend(extra_gates); + let (output5, extra_gates, num_witness) = f.add(&rolling_hash[5], num_witness); + new_gates.extend(extra_gates); + let (output6, extra_gates, num_witness) = g.add(&rolling_hash[6], num_witness); + new_gates.extend(extra_gates); + let (output7, extra_gates, num_witness) = h.add(&rolling_hash[7], num_witness); + new_gates.extend(extra_gates); + + output.push(output0); + output.push(output1); + output.push(output2); + output.push(output3); + output.push(output4); + output.push(output5); + output.push(output6); + output.push(output7); + + (output, new_gates, num_witness) +} + +/// Load initial state constants of Sha256 +pub(crate) fn prepare_state_constants(mut num_witness: u32) -> (Vec, Vec, u32) { + let mut new_gates = Vec::new(); + let mut new_witnesses = Vec::new(); + + for i in INIT_CONSTANTS { + let (new_witness, extra_gates, updated_witness_counter) = + UInt32::load_constant(i, num_witness); + new_gates.extend(extra_gates); + new_witnesses.push(new_witness); + num_witness = updated_witness_counter; + } + + (new_witnesses, new_gates, num_witness) +} + +/// Load round constants of Sha256 +pub(crate) fn prepare_round_constants(mut num_witness: u32) -> (Vec, Vec, u32) { + let mut new_gates = Vec::new(); + let mut new_witnesses = Vec::new(); + + for i in ROUND_CONSTANTS { + let (new_witness, extra_gates, updated_witness_counter) = + UInt32::load_constant(i, num_witness); + new_gates.extend(extra_gates); + new_witnesses.push(new_witness); + num_witness = updated_witness_counter; + } + + (new_witnesses, new_gates, num_witness) +} diff --git a/stdlib/src/blackbox_fallbacks/uint32.rs b/stdlib/src/blackbox_fallbacks/uint32.rs new file mode 100644 index 000000000..799ae1876 --- /dev/null +++ b/stdlib/src/blackbox_fallbacks/uint32.rs @@ -0,0 +1,620 @@ +use crate::helpers::VariableStore; +use acir::{ + brillig::{self, RegisterIndex}, + circuit::{ + brillig::{Brillig, BrilligInputs, BrilligOutputs}, + directives::QuotientDirective, + opcodes::{BlackBoxFuncCall, FunctionInput}, + Opcode, + }, + native_types::{Expression, Witness}, + FieldElement, +}; + +/// UInt32 contains a witness that points to a field element that represents a u32 integer +/// It has a inner field of type [Witness] that points to the field element and width = 32 +// TODO: This can be generalized to u8, u64 and others if needed. +#[derive(Copy, Clone, Debug)] +pub struct UInt32 { + pub(crate) inner: Witness, + width: u32, +} + +impl UInt32 { + #[cfg(any(test, feature = "testing"))] + pub fn get_inner(&self) -> Witness { + self.inner + } +} + +impl UInt32 { + /// Initialize A new [UInt32] type with a [Witness] + pub fn new(witness: Witness) -> Self { + UInt32 { inner: witness, width: 32 } + } + + /// Load a [u128] constant into the circuit + // TODO: This is currently a u128 instead of a u32 because + // in some cases we want to load 2^32 which does not fit in u32 + pub(crate) fn load_constant( + constant: u128, + mut num_witness: u32, + ) -> (UInt32, Vec, u32) { + let mut new_gates = Vec::new(); + let mut variables = VariableStore::new(&mut num_witness); + let new_witness = variables.new_variable(); + + let brillig_opcode = Opcode::Brillig(Brillig { + inputs: vec![BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![], + q_c: FieldElement::from(constant), + })], + outputs: vec![BrilligOutputs::Simple(new_witness)], + foreign_call_results: vec![], + bytecode: vec![brillig::Opcode::Stop], + predicate: None, + }); + new_gates.push(brillig_opcode); + let num_witness = variables.finalize(); + + (UInt32::new(new_witness), new_gates, num_witness) + } + + /// Load a [UInt32] from four [Witness]es each representing a [u8] + pub(crate) fn from_witnesses( + witnesses: &[Witness], + mut num_witness: u32, + ) -> (Vec, Vec, u32) { + let mut new_gates = Vec::new(); + let mut variables = VariableStore::new(&mut num_witness); + let mut uint32 = Vec::new(); + + for i in 0..witnesses.len() / 4 { + let new_witness = variables.new_variable(); + let brillig_opcode = Opcode::Brillig(Brillig { + inputs: vec![ + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), witnesses[i * 4])], + q_c: FieldElement::zero(), + }), + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), witnesses[i * 4 + 1])], + q_c: FieldElement::zero(), + }), + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), witnesses[i * 4 + 2])], + q_c: FieldElement::zero(), + }), + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), witnesses[i * 4 + 3])], + q_c: FieldElement::zero(), + }), + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![], + q_c: FieldElement::from(8_u128), + }), + ], + outputs: vec![BrilligOutputs::Simple(new_witness)], + foreign_call_results: vec![], + bytecode: vec![ + brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::Shl, + bit_size: 32, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(4), + destination: RegisterIndex::from(0), + }, + brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::Add, + bit_size: 32, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(1), + destination: RegisterIndex::from(0), + }, + brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::Shl, + bit_size: 32, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(4), + destination: RegisterIndex::from(0), + }, + brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::Add, + bit_size: 32, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(2), + destination: RegisterIndex::from(0), + }, + brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::Shl, + bit_size: 32, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(4), + destination: RegisterIndex::from(0), + }, + brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::Add, + bit_size: 32, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(3), + destination: RegisterIndex::from(0), + }, + ], + predicate: None, + }); + uint32.push(UInt32::new(new_witness)); + new_gates.push(brillig_opcode); + let mut expr = Expression::from(new_witness); + for j in 0..4 { + let scaling_factor_value = 1 << (8 * (3 - j) as u32); + let scaling_factor = FieldElement::from(scaling_factor_value as u128); + expr.push_addition_term(-scaling_factor, witnesses[i * 4 + j]); + } + + new_gates.push(Opcode::Arithmetic(expr)); + } + let num_witness = variables.finalize(); + + (uint32, new_gates, num_witness) + } + + /// Returns the quotient and remainder such that lhs = rhs * quotient + remainder + // This should be the same as its equivalent in the Noir repo + pub fn euclidean_division( + lhs: &UInt32, + rhs: &UInt32, + mut num_witness: u32, + ) -> (UInt32, UInt32, Vec, u32) { + let mut new_gates = Vec::new(); + let mut variables = VariableStore::new(&mut num_witness); + let q_witness = variables.new_variable(); + let r_witness = variables.new_variable(); + + // compute quotient using directive function + let quotient_opcode = + Opcode::Directive(acir::circuit::directives::Directive::Quotient(QuotientDirective { + a: lhs.inner.into(), + b: rhs.inner.into(), + q: q_witness, + r: r_witness, + predicate: None, + })); + new_gates.push(quotient_opcode); + + // make sure r and q are in 32 bit range + let r_range_opcode = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { + input: FunctionInput { witness: r_witness, num_bits: lhs.width }, + }); + let q_range_opcode = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { + input: FunctionInput { witness: q_witness, num_bits: lhs.width }, + }); + new_gates.push(r_range_opcode); + new_gates.push(q_range_opcode); + let num_witness = variables.finalize(); + + // constrain r < rhs + let (rhs_sub_r, extra_gates, num_witness) = + rhs.sub_no_overflow(&UInt32::new(r_witness), num_witness); + new_gates.extend(extra_gates); + let rhs_sub_r_range_opcode = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { + input: FunctionInput { witness: rhs_sub_r.inner, num_bits: lhs.width }, + }); + new_gates.push(rhs_sub_r_range_opcode); + + // constrain lhs = rhs * quotient + remainder + let rhs_expr = Expression::from(rhs.inner); + let lhs_constraint = Expression::from(lhs.inner); + let rhs_constraint = &rhs_expr * &Expression::from(q_witness); + let rhs_constraint = &rhs_constraint.unwrap() + &Expression::from(r_witness); + let div_euclidean = &lhs_constraint - &rhs_constraint; + new_gates.push(Opcode::Arithmetic(div_euclidean)); + + (UInt32::new(q_witness), UInt32::new(r_witness), new_gates, num_witness) + } + + /// Rotate right `rotation` bits. `(x >> rotation) | (x << (width - rotation))` + // Switched `or` with `add` here + // This should be the same as `u32.rotate_right(rotation)` in rust stdlib + pub fn ror(&self, rotation: u32, num_witness: u32) -> (UInt32, Vec, u32) { + let mut new_gates = Vec::new(); + + let (left_shift, extra_gates, num_witness) = self.leftshift(32 - rotation, num_witness); + new_gates.extend(extra_gates); + let (right_shift, extra_gates, num_witness) = self.rightshift(rotation, num_witness); + new_gates.extend(extra_gates); + let (result, extra_gates, num_witness) = left_shift.add(&right_shift, num_witness); + new_gates.extend(extra_gates); + + (result, new_gates, num_witness) + } + + /// left shift by `bits` + pub fn leftshift(&self, bits: u32, num_witness: u32) -> (UInt32, Vec, u32) { + let mut new_gates = Vec::new(); + let (two_pow_rhs, extra_gates, num_witness) = + UInt32::load_constant(2_u128.pow(bits), num_witness); + new_gates.extend(extra_gates); + let (left_shift, extra_gates, num_witness) = self.mul(&two_pow_rhs, num_witness); + new_gates.extend(extra_gates); + + (left_shift, new_gates, num_witness) + } + + /// right shift by `bits` + pub fn rightshift(&self, bits: u32, num_witness: u32) -> (UInt32, Vec, u32) { + let mut new_gates = Vec::new(); + let (two_pow_rhs, extra_gates, num_witness) = + UInt32::load_constant(2_u128.pow(bits), num_witness); + new_gates.extend(extra_gates); + let (right_shift, _, extra_gates, num_witness) = + UInt32::euclidean_division(self, &two_pow_rhs, num_witness); + new_gates.extend(extra_gates); + + (right_shift, new_gates, num_witness) + } + + /// Caculate and constrain `self` + `rhs` + pub fn add(&self, rhs: &UInt32, mut num_witness: u32) -> (UInt32, Vec, u32) { + let mut new_gates = Vec::new(); + let mut variables = VariableStore::new(&mut num_witness); + let new_witness = variables.new_variable(); + + // calculate `self` + `rhs` with overflow + let brillig_opcode = Opcode::Brillig(Brillig { + inputs: vec![ + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), self.inner)], + q_c: FieldElement::zero(), + }), + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), rhs.inner)], + q_c: FieldElement::zero(), + }), + ], + outputs: vec![BrilligOutputs::Simple(new_witness)], + foreign_call_results: vec![], + bytecode: vec![brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::Add, + bit_size: 127, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(1), + destination: RegisterIndex::from(0), + }], + predicate: None, + }); + new_gates.push(brillig_opcode); + let num_witness = variables.finalize(); + + // constrain addition + let mut add_expr = Expression::from(new_witness); + add_expr.push_addition_term(-FieldElement::one(), self.inner); + add_expr.push_addition_term(-FieldElement::one(), rhs.inner); + new_gates.push(Opcode::Arithmetic(add_expr)); + + // mod 2^width to get final result as the remainder + let (two_pow_width, extra_gates, num_witness) = + UInt32::load_constant(2_u128.pow(self.width), num_witness); + new_gates.extend(extra_gates); + let (_, add_mod, extra_gates, num_witness) = + UInt32::euclidean_division(&UInt32::new(new_witness), &two_pow_width, num_witness); + new_gates.extend(extra_gates); + + (add_mod, new_gates, num_witness) + } + + /// Caculate and constrain `self` - `rhs` + pub fn sub(&self, rhs: &UInt32, mut num_witness: u32) -> (UInt32, Vec, u32) { + let mut new_gates = Vec::new(); + let mut variables = VariableStore::new(&mut num_witness); + let new_witness = variables.new_variable(); + + // calculate 2^32 + self - rhs to avoid overflow + let brillig_opcode = Opcode::Brillig(Brillig { + inputs: vec![ + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), self.inner)], + q_c: FieldElement::zero(), + }), + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), rhs.inner)], + q_c: FieldElement::zero(), + }), + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![], + q_c: FieldElement::from(1_u128 << self.width), + }), + ], + outputs: vec![BrilligOutputs::Simple(new_witness)], + foreign_call_results: vec![], + bytecode: vec![ + brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::Add, + bit_size: 127, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(2), + destination: RegisterIndex::from(0), + }, + brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::Sub, + bit_size: 127, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(1), + destination: RegisterIndex::from(0), + }, + ], + predicate: None, + }); + new_gates.push(brillig_opcode); + let num_witness = variables.finalize(); + + // constrain subtraction + let mut sub_constraint = Expression::from(self.inner); + sub_constraint.push_addition_term(-FieldElement::one(), new_witness); + sub_constraint.push_addition_term(-FieldElement::one(), rhs.inner); + sub_constraint.q_c = FieldElement::from(1_u128 << self.width); + new_gates.push(Opcode::Arithmetic(sub_constraint)); + + // mod 2^width to get final result as the remainder + let (two_pow_width, extra_gates, num_witness) = + UInt32::load_constant(2_u128.pow(self.width), num_witness); + new_gates.extend(extra_gates); + let (_, sub_mod, extra_gates, num_witness) = + UInt32::euclidean_division(&UInt32::new(new_witness), &two_pow_width, num_witness); + new_gates.extend(extra_gates); + + (sub_mod, new_gates, num_witness) + } + + /// Calculate and constrain `self` - `rhs` - 1 without allowing overflow + /// This is a helper function to `euclidean_division` + // There is a `-1` because theres a case where rhs = 2^32 and remainder = 0 + pub(crate) fn sub_no_overflow( + &self, + rhs: &UInt32, + mut num_witness: u32, + ) -> (UInt32, Vec, u32) { + let mut new_gates = Vec::new(); + let mut variables = VariableStore::new(&mut num_witness); + let new_witness = variables.new_variable(); + + // calculate self - rhs - 1 + let brillig_opcode = Opcode::Brillig(Brillig { + inputs: vec![ + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), self.inner)], + q_c: FieldElement::zero(), + }), + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), rhs.inner)], + q_c: FieldElement::zero(), + }), + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![], + q_c: FieldElement::one(), + }), + ], + outputs: vec![BrilligOutputs::Simple(new_witness)], + foreign_call_results: vec![], + bytecode: vec![ + brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::Sub, + bit_size: 127, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(1), + destination: RegisterIndex::from(0), + }, + brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::Sub, + bit_size: 127, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(2), + destination: RegisterIndex::from(0), + }, + ], + predicate: None, + }); + new_gates.push(brillig_opcode); + let num_witness = variables.finalize(); + + // constrain subtraction + let mut sub_constraint = Expression::from(self.inner); + sub_constraint.push_addition_term(-FieldElement::one(), new_witness); + sub_constraint.push_addition_term(-FieldElement::one(), rhs.inner); + sub_constraint.q_c = -FieldElement::one(); + new_gates.push(Opcode::Arithmetic(sub_constraint)); + + (UInt32::new(new_witness), new_gates, num_witness) + } + + /// Calculate and constrain `self` * `rhs` + pub(crate) fn mul(&self, rhs: &UInt32, mut num_witness: u32) -> (UInt32, Vec, u32) { + let mut new_gates = Vec::new(); + let mut variables = VariableStore::new(&mut num_witness); + let new_witness = variables.new_variable(); + + // calulate `self` * `rhs` with overflow + let brillig_opcode = Opcode::Brillig(Brillig { + inputs: vec![ + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), self.inner)], + q_c: FieldElement::zero(), + }), + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), rhs.inner)], + q_c: FieldElement::zero(), + }), + ], + outputs: vec![BrilligOutputs::Simple(new_witness)], + foreign_call_results: vec![], + bytecode: vec![brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::Mul, + bit_size: 127, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(1), + destination: RegisterIndex::from(0), + }], + predicate: None, + }); + new_gates.push(brillig_opcode); + let num_witness = variables.finalize(); + + // constrain mul + let mut mul_constraint = Expression::from(new_witness); + mul_constraint.push_multiplication_term(-FieldElement::one(), self.inner, rhs.inner); + new_gates.push(Opcode::Arithmetic(mul_constraint)); + + // mod 2^width to get final result as the remainder + let (two_pow_rhs, extra_gates, num_witness) = + UInt32::load_constant(2_u128.pow(self.width), num_witness); + new_gates.extend(extra_gates); + let (_, mul_mod, extra_gates, num_witness) = + UInt32::euclidean_division(&UInt32::new(new_witness), &two_pow_rhs, num_witness); + new_gates.extend(extra_gates); + + (mul_mod, new_gates, num_witness) + } + + /// Calculate and constrain `self` and `rhs` + pub(crate) fn and(&self, rhs: &UInt32, mut num_witness: u32) -> (UInt32, Vec, u32) { + let mut new_gates = Vec::new(); + let mut variables = VariableStore::new(&mut num_witness); + let new_witness = variables.new_variable(); + + let brillig_opcode = Opcode::Brillig(Brillig { + inputs: vec![ + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), self.inner)], + q_c: FieldElement::zero(), + }), + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), rhs.inner)], + q_c: FieldElement::zero(), + }), + ], + outputs: vec![BrilligOutputs::Simple(new_witness)], + foreign_call_results: vec![], + bytecode: vec![brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::And, + bit_size: 32, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(1), + destination: RegisterIndex::from(0), + }], + predicate: None, + }); + new_gates.push(brillig_opcode); + let num_witness = variables.finalize(); + + let and_opcode = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::AND { + lhs: FunctionInput { witness: self.inner, num_bits: self.width }, + rhs: FunctionInput { witness: rhs.inner, num_bits: self.width }, + output: new_witness, + }); + new_gates.push(and_opcode); + + (UInt32::new(new_witness), new_gates, num_witness) + } + + /// Calculate and constrain `self` xor `rhs` + pub(crate) fn xor(&self, rhs: UInt32, mut num_witness: u32) -> (UInt32, Vec, u32) { + let mut new_gates = Vec::new(); + let mut variables = VariableStore::new(&mut num_witness); + let new_witness = variables.new_variable(); + + let brillig_opcode = Opcode::Brillig(Brillig { + inputs: vec![ + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), self.inner)], + q_c: FieldElement::zero(), + }), + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), rhs.inner)], + q_c: FieldElement::zero(), + }), + ], + outputs: vec![BrilligOutputs::Simple(new_witness)], + foreign_call_results: vec![], + bytecode: vec![brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::Xor, + bit_size: 32, + lhs: RegisterIndex::from(0), + rhs: RegisterIndex::from(1), + destination: RegisterIndex::from(0), + }], + predicate: None, + }); + new_gates.push(brillig_opcode); + + let xor_opcode = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::XOR { + lhs: FunctionInput { witness: self.inner, num_bits: self.width }, + rhs: FunctionInput { witness: rhs.inner, num_bits: self.width }, + output: new_witness, + }); + new_gates.push(xor_opcode); + + let num_witness = variables.finalize(); + + (UInt32::new(new_witness), new_gates, num_witness) + } + + /// Calculate and constrain not `self` + pub(crate) fn not(&self, mut num_witness: u32) -> (UInt32, Vec, u32) { + let mut new_gates = Vec::new(); + let mut variables = VariableStore::new(&mut num_witness); + let new_witness = variables.new_variable(); + + let brillig_opcode = Opcode::Brillig(Brillig { + inputs: vec![ + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::one(), self.inner)], + q_c: FieldElement::zero(), + }), + BrilligInputs::Single(Expression { + mul_terms: vec![], + linear_combinations: vec![], + q_c: FieldElement::from((1_u128 << self.width) - 1), + }), + ], + outputs: vec![BrilligOutputs::Simple(new_witness)], + foreign_call_results: vec![], + bytecode: vec![brillig::Opcode::BinaryIntOp { + op: brillig::BinaryIntOp::Sub, + bit_size: 32, + lhs: RegisterIndex::from(1), + rhs: RegisterIndex::from(0), + destination: RegisterIndex::from(0), + }], + predicate: None, + }); + new_gates.push(brillig_opcode); + let num_witness = variables.finalize(); + + let mut not_constraint = Expression::from(new_witness); + not_constraint.push_addition_term(FieldElement::one(), self.inner); + not_constraint.q_c = -FieldElement::from((1_u128 << self.width) - 1); + new_gates.push(Opcode::Arithmetic(not_constraint)); + + (UInt32::new(new_witness), new_gates, num_witness) + } +} diff --git a/stdlib/src/blackbox_fallbacks/utils.rs b/stdlib/src/blackbox_fallbacks/utils.rs new file mode 100644 index 000000000..6b4c12ec2 --- /dev/null +++ b/stdlib/src/blackbox_fallbacks/utils.rs @@ -0,0 +1,124 @@ +use crate::helpers::VariableStore; +use acir::{ + circuit::{ + directives::Directive, + opcodes::{BlackBoxFuncCall, FunctionInput}, + Opcode, + }, + native_types::{Expression, Witness}, + FieldElement, +}; + +fn round_to_nearest_mul_8(num_bits: u32) -> u32 { + let remainder = num_bits % 8; + + if remainder == 0 { + return num_bits; + } + + num_bits + 8 - remainder +} + +pub(crate) fn round_to_nearest_byte(num_bits: u32) -> u32 { + round_to_nearest_mul_8(num_bits) / 8 +} + +// Generates opcodes and directives to bit decompose the input `gate` +// Returns the bits and the updated witness counter +// TODO:Ideally, we return the updated witness counter, or we require the input +// TODO to be a VariableStore. We are not doing this because we want migration to +// TODO be less painful +pub(crate) fn bit_decomposition( + gate: Expression, + bit_size: u32, + mut num_witness: u32, +) -> (Vec, Vec, u32) { + let mut new_gates = Vec::new(); + let mut variables = VariableStore::new(&mut num_witness); + + // First create a witness for each bit + let mut bit_vector = Vec::with_capacity(bit_size as usize); + for _ in 0..bit_size { + bit_vector.push(variables.new_variable()) + } + + // Next create a directive which computes those bits. + new_gates.push(Opcode::Directive(Directive::ToLeRadix { + a: gate.clone(), + b: bit_vector.clone(), + radix: 2, + })); + + // Now apply constraints to the bits such that they are the bit decomposition + // of the input and each bit is actually a bit + let mut binary_exprs = Vec::new(); + let mut bit_decomp_constraint = gate; + let mut two_pow: FieldElement = FieldElement::one(); + let two = FieldElement::from(2_i128); + for &bit in &bit_vector { + // Bit constraint to ensure each bit is a zero or one; bit^2 - bit = 0 + let mut expr = Expression::default(); + expr.push_multiplication_term(FieldElement::one(), bit, bit); + expr.push_addition_term(-FieldElement::one(), bit); + binary_exprs.push(Opcode::Arithmetic(expr)); + + // Constraint to ensure that the bits are constrained to be a bit decomposition + // of the input + // ie \sum 2^i * x_i = input + bit_decomp_constraint.push_addition_term(-two_pow, bit); + two_pow = two * two_pow; + } + + new_gates.extend(binary_exprs); + bit_decomp_constraint.sort(); // TODO: we have an issue open to check if this is needed. Ideally, we remove it. + new_gates.push(Opcode::Arithmetic(bit_decomp_constraint)); + + (new_gates, bit_vector, variables.finalize()) +} + +// TODO: Maybe this can be merged with `bit_decomposition` +pub(crate) fn byte_decomposition( + gate: Expression, + num_bytes: u32, + mut num_witness: u32, +) -> (Vec, Vec, u32) { + let mut new_gates = Vec::new(); + let mut variables = VariableStore::new(&mut num_witness); + + // First create a witness for each byte + let mut vector = Vec::with_capacity(num_bytes as usize); + for _ in 0..num_bytes { + vector.push(variables.new_variable()) + } + + // Next create a directive which computes those byte. + new_gates.push(Opcode::Directive(Directive::ToLeRadix { + a: gate.clone(), + b: vector.clone(), + radix: 256, + })); + vector.reverse(); + + // Now apply constraints to the bytes such that they are the byte decomposition + // of the input and each byte is actually a byte + let mut byte_exprs = Vec::new(); + let mut decomp_constraint = gate; + let byte_shift: u128 = 256; + for (i, v) in vector.iter().enumerate() { + let range = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { + input: FunctionInput { witness: *v, num_bits: 8 }, + }); + let scaling_factor_value = byte_shift.pow(num_bytes - 1 - i as u32); + let scaling_factor = FieldElement::from(scaling_factor_value); + + decomp_constraint.push_addition_term(-scaling_factor, *v); + + byte_exprs.push(range); + } + + new_gates.extend(byte_exprs); + decomp_constraint.sort(); + new_gates.push(Opcode::Arithmetic(decomp_constraint)); + + (new_gates, vector, variables.finalize()) +} diff --git a/stdlib/src/lib.rs b/stdlib/src/lib.rs index 02463066d..39d68647c 100644 --- a/stdlib/src/lib.rs +++ b/stdlib/src/lib.rs @@ -1,5 +1,5 @@ #![warn(unused_crate_dependencies)] #![warn(unreachable_pub)] -pub mod fallback; +pub mod blackbox_fallbacks; pub mod helpers;