Skip to content
This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

Commit

Permalink
feat(stdlib): Add fallback implementation of SHA256 black box funct…
Browse files Browse the repository at this point in the history
…ion (#407)

Co-authored-by: kevaundray <[email protected]>
  • Loading branch information
Ethan-000 and kevaundray authored Jul 11, 2023
1 parent 967ec81 commit 040369a
Show file tree
Hide file tree
Showing 12 changed files with 1,360 additions and 73 deletions.
3 changes: 3 additions & 0 deletions acvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ async-trait = "0.1"
default = ["bn254"]
bn254 = ["acir/bn254", "stdlib/bn254", "brillig_vm/bn254", "blackbox_solver/bn254"]
bls12_381 = ["acir/bls12_381", "stdlib/bls12_381", "brillig_vm/bls12_381", "blackbox_solver/bls12_381"]
testing = ["stdlib/testing", "unstable-fallbacks"]
unstable-fallbacks = []

[dev-dependencies]
rand = "0.8.5"
proptest = "1.2.0"
20 changes: 17 additions & 3 deletions acvm/src/compiler/transformers/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl FallbackTransformer {
lhs.num_bits, rhs.num_bits,
"number of bits specified for each input must be the same"
);
stdlib::fallback::and(
stdlib::blackbox_fallbacks::and(
Expression::from(lhs.witness),
Expression::from(rhs.witness),
*output,
Expand All @@ -88,7 +88,7 @@ impl FallbackTransformer {
lhs.num_bits, rhs.num_bits,
"number of bits specified for each input must be the same"
);
stdlib::fallback::xor(
stdlib::blackbox_fallbacks::xor(
Expression::from(lhs.witness),
Expression::from(rhs.witness),
*output,
Expand All @@ -98,12 +98,26 @@ impl FallbackTransformer {
}
BlackBoxFuncCall::RANGE { input } => {
// Note there are no outputs because range produces no outputs
stdlib::fallback::range(
stdlib::blackbox_fallbacks::range(
Expression::from(input.witness),
input.num_bits,
current_witness_idx,
)
}
#[cfg(feature = "unstable-fallbacks")]
BlackBoxFuncCall::SHA256 { inputs, outputs } => {
let mut sha256_inputs = Vec::new();
for input in inputs.iter() {
let witness_index = Expression::from(input.witness);
let num_bits = input.num_bits;
sha256_inputs.push((witness_index, num_bits));
}
stdlib::blackbox_fallbacks::sha256(
sha256_inputs,
outputs.to_vec(),
current_witness_idx,
)
}
_ => {
return Err(CompileError::UnsupportedBlackBox(gc.get_black_box_func()));
}
Expand Down
1 change: 1 addition & 0 deletions acvm/src/pwg/directives/sorting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ pub(super) fn route(inputs: Vec<FieldElement>, outputs: Vec<FieldElement>) -> Ve
mod tests {
use super::route;
use acir::FieldElement;
use proptest as _;
use rand::prelude::*;

fn execute_network(config: Vec<bool>, inputs: Vec<FieldElement>) -> Vec<FieldElement> {
Expand Down
2 changes: 1 addition & 1 deletion acvm/tests/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use acvm::{
};
use blackbox_solver::BlackBoxResolutionError;

struct StubbedBackend;
pub(crate) struct StubbedBackend;

impl BlackBoxFunctionSolver for StubbedBackend {
fn schnorr_verify(
Expand Down
202 changes: 202 additions & 0 deletions acvm/tests/stdlib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
#![cfg(feature = "testing")]
mod solver;
use crate::solver::StubbedBackend;
use acir::{
circuit::{
opcodes::{BlackBoxFuncCall, FunctionInput},
Circuit, Opcode, PublicInputs,
},
native_types::Witness,
FieldElement,
};
use acvm::{
compiler::{compile, CircuitSimplifier},
pwg::{ACVMStatus, ACVM},
Language,
};
use proptest::prelude::*;
use sha2::{Digest, Sha256};
use std::collections::{BTreeMap, BTreeSet};
use stdlib::blackbox_fallbacks::UInt32;

proptest! {
#[test]
fn test_uint32_ror(x in 0..u32::MAX, y in 0..32_u32) {
let fe = FieldElement::from(x as u128);
let w = Witness(1);
let result = x.rotate_right(y);
let sha256_u32 = UInt32::new(w);
let (w, extra_gates, _) = sha256_u32.ror(y, 2);
let witness_assignments = BTreeMap::from([(Witness(1), fe)]).into();
let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments);
let solver_status = acvm.solve();

prop_assert_eq!(acvm.witness_map().get(&w.get_inner()).unwrap(), &FieldElement::from(result as u128));
prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
}

#[test]
fn test_uint32_euclidean_division(x in 0..u32::MAX, y in 0..u32::MAX) {
let lhs = FieldElement::from(x as u128);
let rhs = FieldElement::from(y as u128);
let w1 = Witness(1);
let w2 = Witness(2);
let q = x.div_euclid(y);
let r = x.rem_euclid(y);
let u32_1 = UInt32::new(w1);
let u32_2 = UInt32::new(w2);
let (q_w, r_w, extra_gates, _) = UInt32::euclidean_division(&u32_1, &u32_2, 3);
let witness_assignments = BTreeMap::from([(Witness(1), lhs),(Witness(2), rhs)]).into();
let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments);
let solver_status = acvm.solve();

prop_assert_eq!(acvm.witness_map().get(&q_w.get_inner()).unwrap(), &FieldElement::from(q as u128));
prop_assert_eq!(acvm.witness_map().get(&r_w.get_inner()).unwrap(), &FieldElement::from(r as u128));
prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
}

#[test]
fn test_uint32_add(x in 0..u32::MAX, y in 0..u32::MAX, z in 0..u32::MAX) {
let lhs = FieldElement::from(x as u128);
let rhs = FieldElement::from(y as u128);
let rhs_z = FieldElement::from(z as u128);
let result = FieldElement::from(((x as u128).wrapping_add(y as u128) % (1_u128 << 32)).wrapping_add(z as u128) % (1_u128 << 32));
let w1 = Witness(1);
let w2 = Witness(2);
let w3 = Witness(3);
let u32_1 = UInt32::new(w1);
let u32_2 = UInt32::new(w2);
let u32_3 = UInt32::new(w3);
let mut gates = Vec::new();
let (w, extra_gates, num_witness) = u32_1.add(&u32_2, 4);
gates.extend(extra_gates);
let (w2, extra_gates, _) = w.add(&u32_3, num_witness);
gates.extend(extra_gates);
let witness_assignments = BTreeMap::from([(Witness(1), lhs), (Witness(2), rhs), (Witness(3), rhs_z)]).into();
let mut acvm = ACVM::new(StubbedBackend, gates, witness_assignments);
let solver_status = acvm.solve();

prop_assert_eq!(acvm.witness_map().get(&w2.get_inner()).unwrap(), &result);
prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
}

#[test]
fn test_uint32_sub(x in 0..u32::MAX, y in 0..u32::MAX, z in 0..u32::MAX) {
let lhs = FieldElement::from(x as u128);
let rhs = FieldElement::from(y as u128);
let rhs_z = FieldElement::from(z as u128);
let result = FieldElement::from(((x as u128).wrapping_sub(y as u128) % (1_u128 << 32)).wrapping_sub(z as u128) % (1_u128 << 32));
let w1 = Witness(1);
let w2 = Witness(2);
let w3 = Witness(3);
let u32_1 = UInt32::new(w1);
let u32_2 = UInt32::new(w2);
let u32_3 = UInt32::new(w3);
let mut gates = Vec::new();
let (w, extra_gates, num_witness) = u32_1.sub(&u32_2, 4);
gates.extend(extra_gates);
let (w2, extra_gates, _) = w.sub(&u32_3, num_witness);
gates.extend(extra_gates);
let witness_assignments = BTreeMap::from([(Witness(1), lhs), (Witness(2), rhs), (Witness(3), rhs_z)]).into();
let mut acvm = ACVM::new(StubbedBackend, gates, witness_assignments);
let solver_status = acvm.solve();

prop_assert_eq!(acvm.witness_map().get(&w2.get_inner()).unwrap(), &result);
prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
}

#[test]
fn test_uint32_left_shift(x in 0..u32::MAX, y in 0..32_u32) {
let lhs = FieldElement::from(x as u128);
let w1 = Witness(1);
let result = x.overflowing_shl(y).0;
let u32_1 = UInt32::new(w1);
let (w, extra_gates, _) = u32_1.leftshift(y, 2);
let witness_assignments = BTreeMap::from([(Witness(1), lhs)]).into();
let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments);
let solver_status = acvm.solve();

prop_assert_eq!(acvm.witness_map().get(&w.get_inner()).unwrap(), &FieldElement::from(result as u128));
prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
}

#[test]
fn test_uint32_right_shift(x in 0..u32::MAX, y in 0..32_u32) {
let lhs = FieldElement::from(x as u128);
let w1 = Witness(1);
let result = x.overflowing_shr(y).0;
let u32_1 = UInt32::new(w1);
let (w, extra_gates, _) = u32_1.rightshift(y, 2);
let witness_assignments = BTreeMap::from([(Witness(1), lhs)]).into();
let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments);
let solver_status = acvm.solve();

prop_assert_eq!(acvm.witness_map().get(&w.get_inner()).unwrap(), &FieldElement::from(result as u128));
prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
}
}

proptest! {
#![proptest_config(ProptestConfig::with_cases(3))]
#[test]
fn test_sha256(input_values in proptest::collection::vec(0..u8::MAX, 1..50)) {
let mut opcodes = Vec::new();
let mut witness_assignments = BTreeMap::new();
let mut sha256_input_witnesses: Vec<FunctionInput> = Vec::new();
let mut correct_result_witnesses: Vec<Witness> = Vec::new();
let mut output_witnesses: Vec<Witness> = Vec::new();

// prepare test data
hash_witnesses!(input_values, witness_assignments, sha256_input_witnesses, correct_result_witnesses, output_witnesses, Sha256);
let sha256_blackbox = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::SHA256 { inputs: sha256_input_witnesses, outputs: output_witnesses });
opcodes.push(sha256_blackbox);

// compile circuit
let circuit_simplifier = CircuitSimplifier::new(witness_assignments.len() as u32 + 32);
let circuit = Circuit {current_witness_index: witness_assignments.len() as u32 + 32,
opcodes, public_parameters: PublicInputs(BTreeSet::new()), return_values: PublicInputs(BTreeSet::new()) };
let circuit = compile(circuit, Language::PLONKCSat{ width: 3 }, does_not_support_sha256, &circuit_simplifier).unwrap().0;

// solve witnesses
let mut acvm = ACVM::new(StubbedBackend, circuit.opcodes, witness_assignments.into());
let solver_status = acvm.solve();

prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
}
}

fn does_not_support_sha256(opcode: &Opcode) -> bool {
!matches!(opcode, Opcode::BlackBoxFuncCall(BlackBoxFuncCall::SHA256 { .. }))
}

#[macro_export]
macro_rules! hash_witnesses {
(
$input_values:ident,
$witness_assignments:ident,
$input_witnesses: ident,
$correct_result_witnesses:ident,
$output_witnesses:ident,
$hasher:ident
) => {
let mut counter = 0;
let output = $hasher::digest($input_values.clone());
for inp_v in $input_values {
counter += 1;
let function_input = FunctionInput { witness: Witness(counter), num_bits: 8 };
$input_witnesses.push(function_input);
$witness_assignments.insert(Witness(counter), FieldElement::from(inp_v as u128));
}

for o_v in output {
counter += 1;
$correct_result_witnesses.push(Witness(counter));
$witness_assignments.insert(Witness(counter), FieldElement::from(o_v as u128));
}

for _ in 0..32 {
counter += 1;
$output_witnesses.push(Witness(counter));
}
};
}
1 change: 1 addition & 0 deletions stdlib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ acir.workspace = true
default = ["bn254"]
bn254 = ["acir/bn254"]
bls12_381 = ["acir/bls12_381"]
testing = ["bn254"]
Original file line number Diff line number Diff line change
@@ -1,76 +1,10 @@
use crate::helpers::VariableStore;
use super::utils::bit_decomposition;
use acir::{
acir_field::FieldElement,
circuit::{directives::Directive, Opcode},
circuit::Opcode,
native_types::{Expression, Witness},
};

// Perform bit decomposition on the provided expression
#[deprecated(note = "use bit_decomposition function instead")]
pub fn split(
gate: Expression,
bit_size: u32,
num_witness: u32,
new_gates: &mut Vec<Opcode>,
) -> Vec<Witness> {
let (extra_gates, bits, _) = bit_decomposition(gate, bit_size, num_witness);
new_gates.extend(extra_gates);
bits
}

// Generates opcodes and directives to bit decompose the input `gate`
// Returns the bits and the updated witness counter
// TODO:Ideally, we return the updated witness counter, or we require the input
// TODO to be a VariableStore. We are not doing this because we want migration to
// TODO be less painful
pub(crate) fn bit_decomposition(
gate: Expression,
bit_size: u32,
mut num_witness: u32,
) -> (Vec<Opcode>, Vec<Witness>, u32) {
let mut new_gates = Vec::new();
let mut variables = VariableStore::new(&mut num_witness);

// First create a witness for each bit
let mut bit_vector = Vec::with_capacity(bit_size as usize);
for _ in 0..bit_size {
bit_vector.push(variables.new_variable())
}

// Next create a directive which computes those bits.
new_gates.push(Opcode::Directive(Directive::ToLeRadix {
a: gate.clone(),
b: bit_vector.clone(),
radix: 2,
}));

// Now apply constraints to the bits such that they are the bit decomposition
// of the input and each bit is actually a bit
let mut binary_exprs = Vec::new();
let mut bit_decomp_constraint = gate;
let mut two_pow: FieldElement = FieldElement::one();
let two = FieldElement::from(2_i128);
for &bit in &bit_vector {
// Bit constraint to ensure each bit is a zero or one; bit^2 - bit = 0
let mut expr = Expression::default();
expr.push_multiplication_term(FieldElement::one(), bit, bit);
expr.push_addition_term(-FieldElement::one(), bit);
binary_exprs.push(Opcode::Arithmetic(expr));

// Constraint to ensure that the bits are constrained to be a bit decomposition
// of the input
// ie \sum 2^i * x_i = input
bit_decomp_constraint.push_addition_term(-two_pow, bit);
two_pow = two * two_pow;
}

new_gates.extend(binary_exprs);
bit_decomp_constraint.sort(); // TODO: we have an issue open to check if this is needed. Ideally, we remove it.
new_gates.push(Opcode::Arithmetic(bit_decomp_constraint));

(new_gates, bit_vector, variables.finalize())
}

// Range constraint
pub fn range(gate: Expression, bit_size: u32, num_witness: u32) -> (u32, Vec<Opcode>) {
let (new_gates, _, updated_witness_counter) = bit_decomposition(gate, bit_size, num_witness);
Expand Down
7 changes: 7 additions & 0 deletions stdlib/src/blackbox_fallbacks/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod logic_fallbacks;
mod sha256;
mod uint32;
mod utils;
pub use logic_fallbacks::{and, range, xor};
pub use sha256::sha256;
pub use uint32::UInt32;
Loading

0 comments on commit 040369a

Please sign in to comment.