diff --git a/crates/nargo/src/cli/execute_cmd.rs b/crates/nargo/src/cli/execute_cmd.rs index a89a0fcaa21..d88098c03c4 100644 --- a/crates/nargo/src/cli/execute_cmd.rs +++ b/crates/nargo/src/cli/execute_cmd.rs @@ -1,11 +1,13 @@ -use clap::ArgMatches; +use std::collections::BTreeMap; use std::path::{Path, PathBuf}; use acvm::acir::native_types::Witness; use acvm::{FieldElement, PartialWitnessGenerator}; +use clap::ArgMatches; +use iter_extended::{try_btree_map, vecmap}; use noirc_abi::errors::AbiError; use noirc_abi::input_parser::{Format, InputValue}; -use noirc_abi::{Abi, MAIN_RETURN_NAME}; +use noirc_abi::{decode_value, encode_value, AbiParameter, MAIN_RETURN_NAME}; use noirc_driver::CompiledProgram; use super::{create_named_dir, read_inputs_from_file, write_to_file, InputMap, WitnessMap}; @@ -40,10 +42,6 @@ pub(crate) fn run(args: ArgMatches) -> Result<(), CliError> { Ok(()) } -/// In Barretenberg, the proof system adds a zero witness in the first index, -/// So when we add witness values, their index start from 1. -const WITNESS_OFFSET: u32 = 1; - fn execute_with_path>( program_dir: P, show_ssa: bool, @@ -75,30 +73,12 @@ pub(crate) fn execute_program( Ok((return_value, solved_witness)) } -pub(crate) fn extract_public_inputs( - compiled_program: &CompiledProgram, - solved_witness: &WitnessMap, -) -> Result { - let encoded_public_inputs: Vec = compiled_program - .circuit - .public_inputs - .0 - .iter() - .map(|index| solved_witness[index]) - .collect(); - - let public_abi = compiled_program.abi.as_ref().unwrap().clone().public_abi(); - - public_abi.decode(&encoded_public_inputs) -} - pub(crate) fn solve_witness( compiled_program: &CompiledProgram, input_map: &InputMap, ) -> Result { - let abi = compiled_program.abi.as_ref().unwrap().clone(); - let mut solved_witness = - input_map_to_witness_map(abi, input_map).map_err(|error| match error { + let mut solved_witness = input_map_to_witness_map(input_map, &compiled_program.param_witnesses) + .map_err(|error| match error { AbiError::UndefinedInput(_) => { CliError::Generic(format!("{error} in the {PROVER_INPUT_FILE}.toml file.")) } @@ -115,18 +95,53 @@ pub(crate) fn solve_witness( /// /// In particular, this method shows one how to associate values in a Toml/JSON /// file with witness indices -fn input_map_to_witness_map(abi: Abi, input_map: &InputMap) -> Result { - // The ABI map is first encoded as a vector of field elements - let encoded_inputs = abi.encode(input_map, true)?; - - Ok(encoded_inputs - .into_iter() - .enumerate() - .map(|(index, witness_value)| { - let witness = Witness::new(WITNESS_OFFSET + (index as u32)); - (witness, witness_value) +fn input_map_to_witness_map( + input_map: &InputMap, + abi_witness_map: &BTreeMap>, +) -> Result { + // First encode each input separately + let encoded_input_map: BTreeMap> = + try_btree_map(input_map, |(key, value)| { + encode_value(value.clone(), key).map(|v| (key.clone(), v)) + })?; + + // Write input field elements into witness indices specified in `abi_witness_map`. + let witness_map = encoded_input_map + .iter() + .flat_map(|(param_name, encoded_param_fields)| { + let param_witness_indices = &abi_witness_map[param_name]; + param_witness_indices + .iter() + .zip(encoded_param_fields.iter()) + .map(|(&witness, &field_element)| (witness, field_element)) + }) + .collect(); + + Ok(witness_map) +} + +pub(crate) fn extract_public_inputs( + compiled_program: &CompiledProgram, + solved_witness: &WitnessMap, +) -> Result { + let public_abi = compiled_program.abi.as_ref().unwrap().clone().public_abi(); + + let public_inputs_map = public_abi + .parameters + .iter() + .map(|AbiParameter { name, typ, .. }| { + let param_witness_values = + vecmap(compiled_program.param_witnesses[name].clone(), |witness_index| { + solved_witness[&witness_index] + }); + + decode_value(&mut param_witness_values.into_iter(), typ) + .map(|input_value| (name.clone(), input_value)) + .unwrap() }) - .collect()) + .collect(); + + Ok(public_inputs_map) } pub(crate) fn save_witness_to_dir>( diff --git a/crates/noirc_abi/src/lib.rs b/crates/noirc_abi/src/lib.rs index 7d8bc645102..ad0606174f5 100644 --- a/crates/noirc_abi/src/lib.rs +++ b/crates/noirc_abi/src/lib.rs @@ -183,7 +183,7 @@ impl Abi { return Err(AbiError::TypeMismatch { param: param.to_owned(), value }); } - encoded_inputs.extend(Self::encode_value(value, ¶m.name)?); + encoded_inputs.extend(encode_value(value, ¶m.name)?); } // Check that no extra witness values have been provided. @@ -197,27 +197,6 @@ impl Abi { Ok(encoded_inputs) } - fn encode_value(value: InputValue, param_name: &String) -> Result, AbiError> { - let mut encoded_value = Vec::new(); - match value { - InputValue::Field(elem) => encoded_value.push(elem), - InputValue::Vec(vec_elem) => encoded_value.extend(vec_elem), - InputValue::String(string) => { - let str_as_fields = - string.bytes().map(|byte| FieldElement::from_be_bytes_reduce(&[byte])); - encoded_value.extend(str_as_fields) - } - InputValue::Struct(object) => { - for (field_name, value) in object { - let new_name = format!("{param_name}.{field_name}"); - encoded_value.extend(Self::encode_value(value, &new_name)?) - } - } - InputValue::Undefined => return Err(AbiError::UndefinedInput(param_name.to_string())), - } - Ok(encoded_value) - } - /// Decode a vector of `FieldElements` into the types specified in the ABI. pub fn decode( &self, @@ -234,59 +213,79 @@ impl Abi { let mut field_iterator = encoded_inputs.iter().cloned(); let mut decoded_inputs = BTreeMap::new(); for param in &self.parameters { - let decoded_value = Self::decode_value(&mut field_iterator, ¶m.typ)?; + let decoded_value = decode_value(&mut field_iterator, ¶m.typ)?; decoded_inputs.insert(param.name.to_owned(), decoded_value); } Ok(decoded_inputs) } +} - fn decode_value( - field_iterator: &mut impl Iterator, - value_type: &AbiType, - ) -> Result { - // This function assumes that `field_iterator` contains enough `FieldElement`s in order to decode a `value_type` - // `Abi.decode` enforces that the encoded inputs matches the expected length defined by the ABI so this is safe. - let value = match value_type { - AbiType::Field | AbiType::Integer { .. } | AbiType::Boolean => { - let field_element = field_iterator.next().unwrap(); - - InputValue::Field(field_element) +pub fn encode_value(value: InputValue, param_name: &String) -> Result, AbiError> { + let mut encoded_value = Vec::new(); + match value { + InputValue::Field(elem) => encoded_value.push(elem), + InputValue::Vec(vec_elem) => encoded_value.extend(vec_elem), + InputValue::String(string) => { + let str_as_fields = + string.bytes().map(|byte| FieldElement::from_be_bytes_reduce(&[byte])); + encoded_value.extend(str_as_fields) + } + InputValue::Struct(object) => { + for (field_name, value) in object { + let new_name = format!("{param_name}.{field_name}"); + encoded_value.extend(encode_value(value, &new_name)?) } - AbiType::Array { length, .. } => { - let field_elements: Vec = - field_iterator.take(*length as usize).collect(); + } + InputValue::Undefined => return Err(AbiError::UndefinedInput(param_name.to_string())), + } + Ok(encoded_value) +} - InputValue::Vec(field_elements) - } - AbiType::String { length } => { - let string_as_slice: Vec = field_iterator - .take(*length as usize) - .map(|e| { - let mut field_as_bytes = e.to_be_bytes(); - let char_byte = field_as_bytes.pop().unwrap(); // A character in a string is represented by a u8, thus we just want the last byte of the element - assert!(field_as_bytes.into_iter().all(|b| b == 0)); // Assert that the rest of the field element's bytes are empty - char_byte - }) - .collect(); - - let final_string = str::from_utf8(&string_as_slice).unwrap(); - - InputValue::String(final_string.to_owned()) - } - AbiType::Struct { fields, .. } => { - let mut struct_map = BTreeMap::new(); +pub fn decode_value( + field_iterator: &mut impl Iterator, + value_type: &AbiType, +) -> Result { + // This function assumes that `field_iterator` contains enough `FieldElement`s in order to decode a `value_type` + // `Abi.decode` enforces that the encoded inputs matches the expected length defined by the ABI so this is safe. + let value = match value_type { + AbiType::Field | AbiType::Integer { .. } | AbiType::Boolean => { + let field_element = field_iterator.next().unwrap(); + + InputValue::Field(field_element) + } + AbiType::Array { length, .. } => { + let field_elements: Vec = field_iterator.take(*length as usize).collect(); - for (field_key, param_type) in fields { - let field_value = Self::decode_value(field_iterator, param_type)?; + InputValue::Vec(field_elements) + } + AbiType::String { length } => { + let string_as_slice: Vec = field_iterator + .take(*length as usize) + .map(|e| { + let mut field_as_bytes = e.to_be_bytes(); + let char_byte = field_as_bytes.pop().unwrap(); // A character in a string is represented by a u8, thus we just want the last byte of the element + assert!(field_as_bytes.into_iter().all(|b| b == 0)); // Assert that the rest of the field element's bytes are empty + char_byte + }) + .collect(); + + let final_string = str::from_utf8(&string_as_slice).unwrap(); + + InputValue::String(final_string.to_owned()) + } + AbiType::Struct { fields, .. } => { + let mut struct_map = BTreeMap::new(); - struct_map.insert(field_key.to_owned(), field_value); - } + for (field_key, param_type) in fields { + let field_value = decode_value(field_iterator, param_type)?; - InputValue::Struct(struct_map) + struct_map.insert(field_key.to_owned(), field_value); } - }; - Ok(value) - } + InputValue::Struct(struct_map) + } + }; + + Ok(value) } diff --git a/crates/noirc_driver/src/lib.rs b/crates/noirc_driver/src/lib.rs index 69d19517585..116161f9a62 100644 --- a/crates/noirc_driver/src/lib.rs +++ b/crates/noirc_driver/src/lib.rs @@ -1,6 +1,7 @@ #![forbid(unsafe_code)] use acvm::acir::circuit::Circuit; +use acvm::acir::native_types::Witness; use acvm::Language; use fm::FileType; use noirc_abi::Abi; @@ -12,6 +13,7 @@ use noirc_frontend::hir::Context; use noirc_frontend::monomorphization::monomorphize; use noirc_frontend::node_interner::FuncId; use serde::{Deserialize, Serialize}; +use std::collections::BTreeMap; use std::path::{Path, PathBuf}; pub struct Driver { @@ -22,6 +24,7 @@ pub struct Driver { pub struct CompiledProgram { pub circuit: Circuit, pub abi: Option, + pub param_witnesses: BTreeMap>, } impl Driver { @@ -187,18 +190,21 @@ impl Driver { let program = monomorphize(main_function, &self.context.def_interner); let blackbox_supported = acvm::default_is_black_box_supported(np_language.clone()); - match create_circuit(program, np_language, blackbox_supported, show_ssa) { - Ok(circuit) => Ok(CompiledProgram { circuit, abi: Some(abi) }), - Err(err) => { - // The FileId here will be the file id of the file with the main file - // Errors will be shown at the call site without a stacktrace - let file = err.location.map(|loc| loc.file); - let files = &self.context.file_manager; - let error = reporter::report(files, &err.into(), file, allow_warnings); - reporter::finish_report(error as u32)?; - Err(ReportedError) - } - } + let (circuit, param_witnesses) = + match create_circuit(program, np_language, blackbox_supported, show_ssa) { + Ok((circuit, param_witnesses)) => (circuit, param_witnesses), + Err(err) => { + // The FileId here will be the file id of the file with the main file + // Errors will be shown at the call site without a stacktrace + let file = err.location.map(|loc| loc.file); + let files = &self.context.file_manager; + let error = reporter::report(files, &err.into(), file, allow_warnings); + reporter::finish_report(error as u32)?; + return Err(ReportedError); + } + }; + + Ok(CompiledProgram { circuit, abi: Some(abi), param_witnesses }) } /// Returns a list of all functions in the current crate marked with #[test] diff --git a/crates/noirc_evaluator/src/lib.rs b/crates/noirc_evaluator/src/lib.rs index ee71e2f1021..a05b94d001c 100644 --- a/crates/noirc_evaluator/src/lib.rs +++ b/crates/noirc_evaluator/src/lib.rs @@ -39,7 +39,7 @@ pub fn create_circuit( np_language: Language, is_blackbox_supported: IsBlackBoxSupported, enable_logging: bool, -) -> Result { +) -> Result<(Circuit, BTreeMap>), RuntimeError> { let mut evaluator = Evaluator::new(); // First evaluate the main function @@ -68,7 +68,7 @@ pub fn create_circuit( ) .map_err(|_| RuntimeErrorKind::Spanless(String::from("produced an acvm compile error")))?; - Ok(optimized_circuit) + Ok((optimized_circuit, evaluator.param_witnesses)) } impl Evaluator {