diff --git a/noir/noir-repo/acvm-repo/acvm_js/src/execute.rs b/noir/noir-repo/acvm-repo/acvm_js/src/execute.rs index 0e58ccf039c..c97b8ea1a66 100644 --- a/noir/noir-repo/acvm-repo/acvm_js/src/execute.rs +++ b/noir/noir-repo/acvm-repo/acvm_js/src/execute.rs @@ -13,7 +13,8 @@ use wasm_bindgen::prelude::wasm_bindgen; use crate::{ foreign_call::{resolve_brillig, ForeignCallHandler}, - JsExecutionError, JsWitnessMap, JsWitnessStack, + public_witness::extract_indices, + JsExecutionError, JsSolvedAndReturnWitness, JsWitnessMap, JsWitnessStack, }; #[wasm_bindgen] @@ -58,6 +59,44 @@ pub async fn execute_circuit( Ok(witness_map.into()) } +/// Executes an ACIR circuit to generate the solved witness from the initial witness. +/// This method also extracts the public return values from the solved witness into its own return witness. +/// +/// @param {&WasmBlackBoxFunctionSolver} solver - A black box solver. +/// @param {Uint8Array} circuit - A serialized representation of an ACIR circuit +/// @param {WitnessMap} initial_witness - The initial witness map defining all of the inputs to `circuit`.. +/// @param {ForeignCallHandler} foreign_call_handler - A callback to process any foreign calls from the circuit. +/// @returns {SolvedAndReturnWitness} The solved witness calculated by executing the circuit on the provided inputs, as well as the return witness indices as specified by the circuit. +#[wasm_bindgen(js_name = executeCircuitWithReturnWitness, skip_jsdoc)] +pub async fn execute_circuit_with_return_witness( + solver: &WasmBlackBoxFunctionSolver, + program: Vec, + initial_witness: JsWitnessMap, + foreign_call_handler: ForeignCallHandler, +) -> Result { + console_error_panic_hook::set_once(); + + let program: Program = Program::deserialize_program(&program) + .map_err(|_| JsExecutionError::new("Failed to deserialize circuit. This is likely due to differing serialization formats between ACVM_JS and your compiler".to_string(), None))?; + + let mut witness_stack = execute_program_with_native_program_and_return( + solver, + &program, + initial_witness, + &foreign_call_handler, + ) + .await?; + let solved_witness = + witness_stack.pop().expect("Should have at least one witness on the stack").witness; + + let main_circuit = &program.functions[0]; + let return_witness = + extract_indices(&solved_witness, main_circuit.return_values.0.iter().copied().collect()) + .map_err(|err| JsExecutionError::new(err, None))?; + + Ok((solved_witness, return_witness).into()) +} + /// Executes an ACIR circuit to generate the solved witness from the initial witness. /// /// @param {&WasmBlackBoxFunctionSolver} solver - A black box solver. @@ -127,6 +166,21 @@ async fn execute_program_with_native_type_return( let program: Program = Program::deserialize_program(&program) .map_err(|_| JsExecutionError::new("Failed to deserialize circuit. This is likely due to differing serialization formats between ACVM_JS and your compiler".to_string(), None))?; + execute_program_with_native_program_and_return( + solver, + &program, + initial_witness, + foreign_call_executor, + ) + .await +} + +async fn execute_program_with_native_program_and_return( + solver: &WasmBlackBoxFunctionSolver, + program: &Program, + initial_witness: JsWitnessMap, + foreign_call_executor: &ForeignCallHandler, +) -> Result { let executor = ProgramExecutor::new(&program.functions, &solver.0, foreign_call_executor); let witness_stack = executor.execute(initial_witness.into()).await?; diff --git a/noir/noir-repo/acvm-repo/acvm_js/src/js_witness_map.rs b/noir/noir-repo/acvm-repo/acvm_js/src/js_witness_map.rs index 481b8caaa2d..c4482c4a234 100644 --- a/noir/noir-repo/acvm-repo/acvm_js/src/js_witness_map.rs +++ b/noir/noir-repo/acvm-repo/acvm_js/src/js_witness_map.rs @@ -2,13 +2,23 @@ use acvm::{ acir::native_types::{Witness, WitnessMap}, FieldElement, }; -use js_sys::{JsString, Map}; +use js_sys::{JsString, Map, Object}; use wasm_bindgen::prelude::{wasm_bindgen, JsValue}; #[wasm_bindgen(typescript_custom_section)] const WITNESS_MAP: &'static str = r#" // Map from witness index to hex string value of witness. export type WitnessMap = Map; + +/** + * An execution result containing two witnesses. + * 1. The full solved witness of the execution. + * 2. The return witness which contains the given public return values within the full witness. + */ +export type SolvedAndReturnWitness = { + solvedWitness: WitnessMap; + returnWitness: WitnessMap; +} "#; // WitnessMap @@ -21,6 +31,12 @@ extern "C" { #[wasm_bindgen(constructor, js_class = "Map")] pub fn new() -> JsWitnessMap; + #[wasm_bindgen(extends = Object, js_name = "SolvedAndReturnWitness", typescript_type = "SolvedAndReturnWitness")] + #[derive(Clone, Debug, PartialEq, Eq)] + pub type JsSolvedAndReturnWitness; + + #[wasm_bindgen(constructor, js_class = "Object")] + pub fn new() -> JsSolvedAndReturnWitness; } impl Default for JsWitnessMap { @@ -29,6 +45,12 @@ impl Default for JsWitnessMap { } } +impl Default for JsSolvedAndReturnWitness { + fn default() -> Self { + Self::new() + } +} + impl From for JsWitnessMap { fn from(witness_map: WitnessMap) -> Self { let js_map = JsWitnessMap::new(); @@ -54,6 +76,20 @@ impl From for WitnessMap { } } +impl From<(WitnessMap, WitnessMap)> for JsSolvedAndReturnWitness { + fn from(witness_maps: (WitnessMap, WitnessMap)) -> Self { + let js_solved_witness = JsWitnessMap::from(witness_maps.0); + let js_return_witness = JsWitnessMap::from(witness_maps.1); + + let entry_map = Map::new(); + entry_map.set(&JsValue::from_str("solvedWitness"), &js_solved_witness); + entry_map.set(&JsValue::from_str("returnWitness"), &js_return_witness); + + let solved_and_return_witness = Object::from_entries(&entry_map).unwrap(); + JsSolvedAndReturnWitness { obj: solved_and_return_witness } + } +} + pub(crate) fn js_value_to_field_element(js_value: JsValue) -> Result { let hex_str = js_value.as_string().ok_or("failed to parse field element from non-string")?; diff --git a/noir/noir-repo/acvm-repo/acvm_js/src/lib.rs b/noir/noir-repo/acvm-repo/acvm_js/src/lib.rs index d7ecc0ae192..66a4388b132 100644 --- a/noir/noir-repo/acvm-repo/acvm_js/src/lib.rs +++ b/noir/noir-repo/acvm-repo/acvm_js/src/lib.rs @@ -22,9 +22,10 @@ pub use compression::{ }; pub use execute::{ create_black_box_solver, execute_circuit, execute_circuit_with_black_box_solver, - execute_program, execute_program_with_black_box_solver, + execute_circuit_with_return_witness, execute_program, execute_program_with_black_box_solver, }; pub use js_execution_error::JsExecutionError; +pub use js_witness_map::JsSolvedAndReturnWitness; pub use js_witness_map::JsWitnessMap; pub use js_witness_stack::JsWitnessStack; pub use logging::init_log_level; diff --git a/noir/noir-repo/acvm-repo/acvm_js/src/public_witness.rs b/noir/noir-repo/acvm-repo/acvm_js/src/public_witness.rs index a0d5b5f8be2..4ba054732d4 100644 --- a/noir/noir-repo/acvm-repo/acvm_js/src/public_witness.rs +++ b/noir/noir-repo/acvm-repo/acvm_js/src/public_witness.rs @@ -7,7 +7,10 @@ use wasm_bindgen::prelude::wasm_bindgen; use crate::JsWitnessMap; -fn extract_indices(witness_map: &WitnessMap, indices: Vec) -> Result { +pub(crate) fn extract_indices( + witness_map: &WitnessMap, + indices: Vec, +) -> Result { let mut extracted_witness_map = WitnessMap::new(); for witness in indices { let witness_value = witness_map.get(&witness).ok_or(format!( @@ -44,7 +47,7 @@ pub fn get_return_witness( let witness_map = WitnessMap::from(witness_map); let return_witness = - extract_indices(&witness_map, circuit.return_values.0.clone().into_iter().collect())?; + extract_indices(&witness_map, circuit.return_values.0.iter().copied().collect())?; Ok(JsWitnessMap::from(return_witness)) } @@ -71,7 +74,7 @@ pub fn get_public_parameters_witness( let witness_map = WitnessMap::from(solved_witness); let public_params_witness = - extract_indices(&witness_map, circuit.public_parameters.0.clone().into_iter().collect())?; + extract_indices(&witness_map, circuit.public_parameters.0.iter().copied().collect())?; Ok(JsWitnessMap::from(public_params_witness)) } diff --git a/yarn-project/simulator/src/acvm/acvm.ts b/yarn-project/simulator/src/acvm/acvm.ts index 6d7101ba64e..d166b5d16c3 100644 --- a/yarn-project/simulator/src/acvm/acvm.ts +++ b/yarn-project/simulator/src/acvm/acvm.ts @@ -7,7 +7,7 @@ import { type ForeignCallInput, type ForeignCallOutput, type WasmBlackBoxFunctionSolver, - executeCircuitWithBlackBoxSolver, + executeCircuitWithReturnWitness, } from '@noir-lang/acvm_js'; import { traverseCauseChain } from '../common/errors.js'; @@ -27,9 +27,12 @@ type ACIRCallback = Record< */ export interface ACIRExecutionResult { /** - * The partial witness of the execution. + * An execution result contains two witnesses. + * 1. The partial witness of the execution. + * 2. The return witness which contains the given public return values within the full witness. */ partialWitness: ACVMWitness; + returnWitness: ACVMWitness; } /** @@ -89,7 +92,7 @@ export async function acvm( ): Promise { const logger = createDebugLogger('aztec:simulator:acvm'); - const partialWitness = await executeCircuitWithBlackBoxSolver( + const solvedAndReturnWitness = await executeCircuitWithReturnWitness( solver, acir, initialWitness, @@ -127,7 +130,7 @@ export async function acvm( throw err; }); - return { partialWitness }; + return { partialWitness: solvedAndReturnWitness.solvedWitness, returnWitness: solvedAndReturnWitness.returnWitness }; } /** diff --git a/yarn-project/simulator/src/acvm/deserialize.ts b/yarn-project/simulator/src/acvm/deserialize.ts index 74701582330..5936d381a37 100644 --- a/yarn-project/simulator/src/acvm/deserialize.ts +++ b/yarn-project/simulator/src/acvm/deserialize.ts @@ -1,7 +1,5 @@ import { Fr } from '@aztec/foundation/fields'; -import { getReturnWitness } from '@noir-lang/acvm_js'; - import { type ACVMField, type ACVMWitness } from './acvm_types.js'; /** @@ -32,13 +30,11 @@ export function frToBoolean(fr: Fr): boolean { } /** - * Extracts the return fields of a given partial witness. - * @param acir - The bytecode of the function. - * @param partialWitness - The witness to extract from. + * Transforms a witness map to its field elements. + * @param witness - The witness to extract from. * @returns The return values. */ -export function extractReturnWitness(acir: Buffer, partialWitness: ACVMWitness): Fr[] { - const returnWitness = getReturnWitness(acir, partialWitness); - const sortedKeys = [...returnWitness.keys()].sort((a, b) => a - b); - return sortedKeys.map(key => returnWitness.get(key)!).map(fromACVMField); +export function witnessMapToFields(witness: ACVMWitness): Fr[] { + const sortedKeys = [...witness.keys()].sort((a, b) => a - b); + return sortedKeys.map(key => witness.get(key)!).map(fromACVMField); } diff --git a/yarn-project/simulator/src/client/private_execution.ts b/yarn-project/simulator/src/client/private_execution.ts index 964e51cf95f..7789711b296 100644 --- a/yarn-project/simulator/src/client/private_execution.ts +++ b/yarn-project/simulator/src/client/private_execution.ts @@ -4,7 +4,7 @@ import { type AztecAddress } from '@aztec/foundation/aztec-address'; import { Fr } from '@aztec/foundation/fields'; import { createDebugLogger } from '@aztec/foundation/log'; -import { extractReturnWitness } from '../acvm/deserialize.js'; +import { witnessMapToFields } from '../acvm/deserialize.js'; import { Oracle, acvm, extractCallStack } from '../acvm/index.js'; import { ExecutionError } from '../common/errors.js'; import { type ClientExecutionContext } from './client_execution_context.js'; @@ -26,7 +26,7 @@ export async function executePrivateFunction( const acir = artifact.bytecode; const initialWitness = context.getInitialWitness(artifact); const acvmCallback = new Oracle(context); - const { partialWitness } = await acvm(await AcirSimulator.getSolver(), acir, initialWitness, acvmCallback).catch( + const acirExecutionResult = await acvm(await AcirSimulator.getSolver(), acir, initialWitness, acvmCallback).catch( (err: Error) => { throw new ExecutionError( err.message, @@ -39,8 +39,8 @@ export async function executePrivateFunction( ); }, ); - - const returnWitness = extractReturnWitness(acir, partialWitness); + const partialWitness = acirExecutionResult.partialWitness; + const returnWitness = witnessMapToFields(acirExecutionResult.returnWitness); const publicInputs = PrivateCircuitPublicInputs.fromFields(returnWitness); const encryptedLogs = context.getEncryptedLogs(); diff --git a/yarn-project/simulator/src/client/unconstrained_execution.ts b/yarn-project/simulator/src/client/unconstrained_execution.ts index 559a8cf1f48..d821ca9fea9 100644 --- a/yarn-project/simulator/src/client/unconstrained_execution.ts +++ b/yarn-project/simulator/src/client/unconstrained_execution.ts @@ -4,7 +4,7 @@ import { type AztecAddress } from '@aztec/foundation/aztec-address'; import { type Fr } from '@aztec/foundation/fields'; import { createDebugLogger } from '@aztec/foundation/log'; -import { extractReturnWitness } from '../acvm/deserialize.js'; +import { witnessMapToFields } from '../acvm/deserialize.js'; import { Oracle, acvm, extractCallStack, toACVMWitness } from '../acvm/index.js'; import { ExecutionError } from '../common/errors.js'; import { AcirSimulator } from './simulator.js'; @@ -27,7 +27,7 @@ export async function executeUnconstrainedFunction( const acir = artifact.bytecode; const initialWitness = toACVMWitness(0, args); - const { partialWitness } = await acvm( + const acirExecutionResult = await acvm( await AcirSimulator.getSolver(), acir, initialWitness, @@ -44,6 +44,7 @@ export async function executeUnconstrainedFunction( ); }); - return decodeReturnValues(artifact, extractReturnWitness(acir, partialWitness)); + const returnWitness = witnessMapToFields(acirExecutionResult.returnWitness); + return decodeReturnValues(artifact, returnWitness); } // docs:end:execute_unconstrained_function diff --git a/yarn-project/simulator/src/public/executor.ts b/yarn-project/simulator/src/public/executor.ts index 854223ae3cd..08e534b5da2 100644 --- a/yarn-project/simulator/src/public/executor.ts +++ b/yarn-project/simulator/src/public/executor.ts @@ -6,7 +6,7 @@ import { spawn } from 'child_process'; import fs from 'fs/promises'; import path from 'path'; -import { Oracle, acvm, extractCallStack, extractReturnWitness } from '../acvm/index.js'; +import { Oracle, acvm, extractCallStack, witnessMapToFields } from '../acvm/index.js'; import { AvmContext } from '../avm/avm_context.js'; import { AvmMachineState } from '../avm/avm_machine_state.js'; import { AvmSimulator } from '../avm/avm_simulator.js'; @@ -97,11 +97,12 @@ async function executePublicFunctionAcvm( const initialWitness = context.getInitialWitness(); const acvmCallback = new Oracle(context); - const { partialWitness, reverted, revertReason } = await (async () => { + const { partialWitness, returnWitnessMap, reverted, revertReason } = await (async () => { try { const result = await acvm(await AcirSimulator.getSolver(), acir, initialWitness, acvmCallback); return { partialWitness: result.partialWitness, + returnWitnessMap: result.returnWitness, reverted: false, revertReason: undefined, }; @@ -123,6 +124,7 @@ async function executePublicFunctionAcvm( } else { return { partialWitness: undefined, + returnWitnessMap: undefined, reverted: true, revertReason: createSimulationError(ee), }; @@ -159,7 +161,7 @@ async function executePublicFunctionAcvm( throw new Error('No partial witness returned from ACVM'); } - const returnWitness = extractReturnWitness(acir, partialWitness); + const returnWitness = witnessMapToFields(returnWitnessMap); const { returnValues, nullifierReadRequests: nullifierReadRequestsPadded,