Skip to content
This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

Commit

Permalink
feat(acvm_js): Add execute_circuit_with_black_box_solver to prevent…
Browse files Browse the repository at this point in the history
… reinitialization of `BlackBoxFunctionSolver` (#495)

Co-authored-by: sirasistant <[email protected]>
Co-authored-by: Tom French <[email protected]>
  • Loading branch information
3 people authored Aug 18, 2023
1 parent 14975ef commit 3877e0e
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 32 deletions.
12 changes: 6 additions & 6 deletions acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ impl From<BlackBoxResolutionError> for OpcodeResolutionError {
}
}

pub struct ACVM<B: BlackBoxFunctionSolver> {
pub struct ACVM<'backend, B: BlackBoxFunctionSolver> {
status: ACVMStatus,

backend: B,
backend: &'backend B,

/// Stores the solver for memory operations acting on blocks of memory disambiguated by [block][`BlockId`].
block_solvers: HashMap<BlockId, MemoryOpSolver>,
Expand All @@ -142,8 +142,8 @@ pub struct ACVM<B: BlackBoxFunctionSolver> {
witness_map: WitnessMap,
}

impl<B: BlackBoxFunctionSolver> ACVM<B> {
pub fn new(backend: B, opcodes: Vec<Opcode>, initial_witness: WitnessMap) -> Self {
impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> {
pub fn new(backend: &'backend B, opcodes: Vec<Opcode>, initial_witness: WitnessMap) -> Self {
let status = if opcodes.is_empty() { ACVMStatus::Solved } else { ACVMStatus::InProgress };
ACVM {
status,
Expand Down Expand Up @@ -246,7 +246,7 @@ impl<B: BlackBoxFunctionSolver> ACVM<B> {
let resolution = match opcode {
Opcode::Arithmetic(expr) => ArithmeticSolver::solve(&mut self.witness_map, expr),
Opcode::BlackBoxFuncCall(bb_func) => {
blackbox::solve(&self.backend, &mut self.witness_map, bb_func)
blackbox::solve(self.backend, &mut self.witness_map, bb_func)
}
Opcode::Directive(directive) => solve_directives(&mut self.witness_map, directive),
Opcode::MemoryInit { block_id, init } => {
Expand All @@ -258,7 +258,7 @@ impl<B: BlackBoxFunctionSolver> ACVM<B> {
solver.solve_memory_op(op, &mut self.witness_map)
}
Opcode::Brillig(brillig) => {
match BrilligSolver::solve(&mut self.witness_map, brillig, &self.backend) {
match BrilligSolver::solve(&mut self.witness_map, brillig, self.backend) {
Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call),
res => res.map(|_| ()),
}
Expand Down
14 changes: 7 additions & 7 deletions acvm/tests/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ fn inversion_brillig_oracle_equivalence() {
])
.into();

let mut acvm = ACVM::new(StubbedBackend, opcodes, witness_assignments);
let mut acvm = ACVM::new(&StubbedBackend, opcodes, witness_assignments);
// use the partial witness generation solver with our acir program
let solver_status = acvm.solve();

Expand Down Expand Up @@ -256,7 +256,7 @@ fn double_inversion_brillig_oracle() {
])
.into();

let mut acvm = ACVM::new(StubbedBackend, opcodes, witness_assignments);
let mut acvm = ACVM::new(&StubbedBackend, opcodes, witness_assignments);

// use the partial witness generation solver with our acir program
let solver_status = acvm.solve();
Expand Down Expand Up @@ -377,7 +377,7 @@ fn oracle_dependent_execution() {
let witness_assignments =
BTreeMap::from([(w_x, FieldElement::from(2u128)), (w_y, FieldElement::from(2u128))]).into();

let mut acvm = ACVM::new(StubbedBackend, opcodes, witness_assignments);
let mut acvm = ACVM::new(&StubbedBackend, opcodes, witness_assignments);

// use the partial witness generation solver with our acir program
let solver_status = acvm.solve();
Expand Down Expand Up @@ -499,7 +499,7 @@ fn brillig_oracle_predicate() {
])
.into();

let mut acvm = ACVM::new(StubbedBackend, opcodes, witness_assignments);
let mut acvm = ACVM::new(&StubbedBackend, opcodes, witness_assignments);
let solver_status = acvm.solve();
assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");

Expand Down Expand Up @@ -533,7 +533,7 @@ fn unsatisfied_opcode_resolved() {
values.insert(d, FieldElement::from(2_i128));

let opcodes = vec![Opcode::Arithmetic(gate_a)];
let mut acvm = ACVM::new(StubbedBackend, opcodes, values);
let mut acvm = ACVM::new(&StubbedBackend, opcodes, values);
let solver_status = acvm.solve();
assert_eq!(
solver_status,
Expand Down Expand Up @@ -615,7 +615,7 @@ fn unsatisfied_opcode_resolved_brillig() {

let opcodes = vec![brillig_opcode, Opcode::Arithmetic(gate_a)];

let mut acvm = ACVM::new(StubbedBackend, opcodes, values);
let mut acvm = ACVM::new(&StubbedBackend, opcodes, values);
let solver_status = acvm.solve();
assert_eq!(
solver_status,
Expand Down Expand Up @@ -658,7 +658,7 @@ fn memory_operations() {

let opcodes = vec![init, read_op, expression];

let mut acvm = ACVM::new(StubbedBackend, opcodes, initial_witness);
let mut acvm = ACVM::new(&StubbedBackend, opcodes, initial_witness);
let solver_status = acvm.solve();
assert_eq!(solver_status, ACVMStatus::Solved);
let witness_map = acvm.finalize();
Expand Down
20 changes: 10 additions & 10 deletions acvm/tests/stdlib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ macro_rules! test_uint_inner {
let uint = $uint::new(w);
let (w, extra_gates, _) = uint.rol(y, 2);
let witness_assignments = BTreeMap::from([(Witness(1), fe)]).into();
let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments);
let mut acvm = ACVM::new(&StubbedBackend, extra_gates, witness_assignments);
let solver_status = acvm.solve();

prop_assert_eq!(acvm.witness_map().get(&w.get_inner()).unwrap(), &FieldElement::from(result as u128));
Expand All @@ -89,7 +89,7 @@ macro_rules! test_uint_inner {
let uint = $uint::new(w);
let (w, extra_gates, _) = uint.ror(y, 2);
let witness_assignments = BTreeMap::from([(Witness(1), fe)]).into();
let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments);
let mut acvm = ACVM::new(&StubbedBackend, extra_gates, witness_assignments);
let solver_status = acvm.solve();

prop_assert_eq!(acvm.witness_map().get(&w.get_inner()).unwrap(), &FieldElement::from(result as u128));
Expand All @@ -109,7 +109,7 @@ macro_rules! test_uint_inner {
let u32_2 = $uint::new(w2);
let (q_w, r_w, extra_gates, _) = $uint::euclidean_division(&u32_1, &u32_2, 3);
let witness_assignments = BTreeMap::from([(Witness(1), lhs),(Witness(2), rhs)]).into();
let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments);
let mut acvm = ACVM::new(&StubbedBackend, extra_gates, witness_assignments);
let solver_status = acvm.solve();

prop_assert_eq!(acvm.witness_map().get(&q_w.get_inner()).unwrap(), &FieldElement::from(q as u128));
Expand All @@ -135,7 +135,7 @@ macro_rules! test_uint_inner {
let (w2, extra_gates, _) = w.add(&u32_3, num_witness);
gates.extend(extra_gates);
let witness_assignments = BTreeMap::from([(Witness(1), lhs), (Witness(2), rhs), (Witness(3), rhs_z)]).into();
let mut acvm = ACVM::new(StubbedBackend, gates, witness_assignments);
let mut acvm = ACVM::new(&StubbedBackend, gates, witness_assignments);
let solver_status = acvm.solve();

prop_assert_eq!(acvm.witness_map().get(&w2.get_inner()).unwrap(), &result);
Expand All @@ -160,7 +160,7 @@ macro_rules! test_uint_inner {
let (w2, extra_gates, _) = w.sub(&u32_3, num_witness);
gates.extend(extra_gates);
let witness_assignments = BTreeMap::from([(Witness(1), lhs), (Witness(2), rhs), (Witness(3), rhs_z)]).into();
let mut acvm = ACVM::new(StubbedBackend, gates, witness_assignments);
let mut acvm = ACVM::new(&StubbedBackend, gates, witness_assignments);
let solver_status = acvm.solve();

prop_assert_eq!(acvm.witness_map().get(&w2.get_inner()).unwrap(), &result);
Expand All @@ -175,7 +175,7 @@ macro_rules! test_uint_inner {
let u32_1 = $uint::new(w1);
let (w, extra_gates, _) = u32_1.leftshift(y, 2);
let witness_assignments = BTreeMap::from([(Witness(1), lhs)]).into();
let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments);
let mut acvm = ACVM::new(&StubbedBackend, extra_gates, witness_assignments);
let solver_status = acvm.solve();

prop_assert_eq!(acvm.witness_map().get(&w.get_inner()).unwrap(), &FieldElement::from(result as u128));
Expand All @@ -190,7 +190,7 @@ macro_rules! test_uint_inner {
let u32_1 = $uint::new(w1);
let (w, extra_gates, _) = u32_1.rightshift(y, 2);
let witness_assignments = BTreeMap::from([(Witness(1), lhs)]).into();
let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments);
let mut acvm = ACVM::new(&StubbedBackend, extra_gates, witness_assignments);
let solver_status = acvm.solve();

prop_assert_eq!(acvm.witness_map().get(&w.get_inner()).unwrap(), &FieldElement::from(result as u128));
Expand All @@ -208,7 +208,7 @@ macro_rules! test_uint_inner {
let u32_2 = $uint::new(w2);
let (w, extra_gates, _) = u32_1.less_than_comparison(&u32_2, 3);
let witness_assignments = BTreeMap::from([(Witness(1), lhs), (Witness(2), rhs)]).into();
let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments);
let mut acvm = ACVM::new(&StubbedBackend, extra_gates, witness_assignments);
let solver_status = acvm.solve();

prop_assert_eq!(acvm.witness_map().get(&w.get_inner()).unwrap(), &FieldElement::from(result as u128));
Expand Down Expand Up @@ -290,7 +290,7 @@ macro_rules! test_hashes {
let circuit = compile(circuit, Language::PLONKCSat{ width: 3 }, $opcode_support).unwrap().0;

// solve witnesses
let mut acvm = ACVM::new(StubbedBackend, circuit.opcodes, witness_assignments.into());
let mut acvm = ACVM::new(&StubbedBackend, circuit.opcodes, witness_assignments.into());
let solver_status = acvm.solve();

prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
Expand Down Expand Up @@ -347,7 +347,7 @@ proptest! {
let circuit = compile(circuit, Language::PLONKCSat{ width: 3 }, does_not_support_hash_to_field).unwrap().0;

// solve witnesses
let mut acvm = ACVM::new(StubbedBackend, circuit.opcodes, witness_assignments.into());
let mut acvm = ACVM::new(&StubbedBackend, circuit.opcodes, witness_assignments.into());
let solver_status = acvm.solve();

prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
Expand Down
41 changes: 34 additions & 7 deletions acvm_js/src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,24 @@ use crate::{
JsWitnessMap,
};

struct SimulatedBackend {
#[wasm_bindgen]
pub struct WasmBlackBoxFunctionSolver {
blackbox_vendor: Barretenberg,
}

impl SimulatedBackend {
async fn initialize() -> SimulatedBackend {
impl WasmBlackBoxFunctionSolver {
async fn initialize() -> WasmBlackBoxFunctionSolver {
let blackbox_vendor = Barretenberg::new().await;
SimulatedBackend { blackbox_vendor }
WasmBlackBoxFunctionSolver { blackbox_vendor }
}
}

impl BlackBoxFunctionSolver for SimulatedBackend {
#[wasm_bindgen(js_name = "createBlackBoxSolver")]
pub async fn create_black_box_solver() -> WasmBlackBoxFunctionSolver {
WasmBlackBoxFunctionSolver::initialize().await
}

impl BlackBoxFunctionSolver for WasmBlackBoxFunctionSolver {
fn schnorr_verify(
&self,
public_key_x: &FieldElement,
Expand Down Expand Up @@ -76,10 +82,31 @@ pub async fn execute_circuit(
foreign_call_handler: ForeignCallHandler,
) -> Result<JsWitnessMap, js_sys::JsString> {
console_error_panic_hook::set_once();

let solver = WasmBlackBoxFunctionSolver::initialize().await;

execute_circuit_with_black_box_solver(&solver, circuit, initial_witness, foreign_call_handler)
.await
}

/// Executes an ACIR circuit to generate the solved witness from the initial witness.
///
/// @param {&WasmBlackBoxFunctionSolver} solver - A black box solver.
/// @param {Uint8Array} circuit - A serialized representation of an ACIR circuit
/// @param {WitnessMap} initial_witness - The initial witness map defining all of the inputs to `circuit`..
/// @param {ForeignCallHandler} foreign_call_handler - A callback to process any foreign calls from the circuit.
/// @returns {WitnessMap} The solved witness calculated by executing the circuit on the provided inputs.
#[wasm_bindgen(js_name = executeCircuitWithBlackBoxSolver, skip_jsdoc)]
pub async fn execute_circuit_with_black_box_solver(
solver: &WasmBlackBoxFunctionSolver,
circuit: Vec<u8>,
initial_witness: JsWitnessMap,
foreign_call_handler: ForeignCallHandler,
) -> Result<JsWitnessMap, js_sys::JsString> {
console_error_panic_hook::set_once();
let circuit: Circuit = Circuit::read(&*circuit).expect("Failed to deserialize circuit");

let backend = SimulatedBackend::initialize().await;
let mut acvm = ACVM::new(backend, circuit.opcodes, initial_witness.into());
let mut acvm = ACVM::new(solver, circuit.opcodes, initial_witness.into());

loop {
let solver_status = acvm.solve();
Expand Down
2 changes: 1 addition & 1 deletion acvm_js/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ cfg_if::cfg_if! {

pub use build_info::build_info;
pub use compression::{compress_witness, decompress_witness};
pub use execute::execute_circuit;
pub use execute::{execute_circuit, execute_circuit_with_black_box_solver, create_black_box_solver};
pub use js_witness_map::JsWitnessMap;
pub use logging::{init_log_level, LogLevel};
pub use public_witness::{get_public_parameters_witness, get_public_witness, get_return_witness};
Expand Down
37 changes: 37 additions & 0 deletions acvm_js/test/browser/execute_circuit.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import { expect } from "@esm-bundle/chai";
import initACVM, {
createBlackBoxSolver,
executeCircuit,
executeCircuitWithBlackBoxSolver,
WasmBlackBoxFunctionSolver,
WitnessMap,
initLogLevel,
ForeignCallHandler,
Expand Down Expand Up @@ -56,6 +59,7 @@ it("successfully processes simple brillig foreign call opcodes", async () => {

return oracleResponse;
};

const solved_witness: WitnessMap = await executeCircuit(
bytecode,
initialWitnessMap,
Expand Down Expand Up @@ -94,6 +98,7 @@ it("successfully processes complex brillig foreign call opcodes", async () => {

return oracleResponse;
};

const solved_witness: WitnessMap = await executeCircuit(
bytecode,
initialWitnessMap,
Expand Down Expand Up @@ -156,3 +161,35 @@ it("successfully executes a SchnorrVerify opcode", async () => {

expect(solvedWitness).to.be.deep.eq(expectedWitnessMap);
});

it("successfully executes two circuits with same backend", async function () {
// chose pedersen op here because it is the one with slow initialization
// that led to the decision to pull backend initialization into a separate
// function/wasmbind
const solver: WasmBlackBoxFunctionSolver = await createBlackBoxSolver();

const { bytecode, initialWitnessMap, expectedWitnessMap } = await import(
"../shared/pedersen"
);

const solvedWitness0: WitnessMap = await executeCircuitWithBlackBoxSolver(
solver,
bytecode,
initialWitnessMap,
() => {
throw Error("unexpected oracle");
}
);

expect(solvedWitness0).to.be.deep.eq(expectedWitnessMap);

const solvedWitness1: WitnessMap = await executeCircuitWithBlackBoxSolver(
solver,
bytecode,
initialWitnessMap,
() => {
throw Error("unexpected oracle");
}
);
expect(solvedWitness1).to.be.deep.eq(expectedWitnessMap);
});
39 changes: 39 additions & 0 deletions acvm_js/test/node/execute_circuit.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import { expect } from "chai";
import {
createBlackBoxSolver,
executeCircuit,
executeCircuitWithBlackBoxSolver,
WasmBlackBoxFunctionSolver,
WitnessMap,
ForeignCallHandler,
} from "../../../result/";
Expand Down Expand Up @@ -49,6 +52,7 @@ it("successfully processes simple brillig foreign call opcodes", async () => {

return oracleResponse;
};

const solved_witness: WitnessMap = await executeCircuit(
bytecode,
initialWitnessMap,
Expand Down Expand Up @@ -87,6 +91,7 @@ it("successfully processes complex brillig foreign call opcodes", async () => {

return oracleResponse;
};

const solved_witness: WitnessMap = await executeCircuit(
bytecode,
initialWitnessMap,
Expand Down Expand Up @@ -150,3 +155,37 @@ it("successfully executes a SchnorrVerify opcode", async () => {

expect(solvedWitness).to.be.deep.eq(expectedWitnessMap);
});

it("successfully executes two circuits with same backend", async function () {
this.timeout(10000);

// chose pedersen op here because it is the one with slow initialization
// that led to the decision to pull backend initialization into a separate
// function/wasmbind
const solver: WasmBlackBoxFunctionSolver = await createBlackBoxSolver();

const { bytecode, initialWitnessMap, expectedWitnessMap } = await import(
"../shared/pedersen"
);

const solvedWitness0 = await executeCircuitWithBlackBoxSolver(
solver,
bytecode,
initialWitnessMap,
() => {
throw Error("unexpected oracle");
}
);

const solvedWitness1 = await executeCircuitWithBlackBoxSolver(
solver,
bytecode,
initialWitnessMap,
() => {
throw Error("unexpected oracle");
}
);

expect(solvedWitness0).to.be.deep.eq(expectedWitnessMap);
expect(solvedWitness1).to.be.deep.eq(expectedWitnessMap);
});
Loading

0 comments on commit 3877e0e

Please sign in to comment.