Skip to content

Commit

Permalink
chore: use explicit witness map for encoding/decoding witness map
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAFrench committed Feb 15, 2023
1 parent 2f68fe8 commit 73a083e
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 115 deletions.
89 changes: 52 additions & 37 deletions crates/nargo/src/cli/execute_cmd.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use clap::ArgMatches;
use std::collections::BTreeMap;
use std::path::{Path, PathBuf};

use acvm::acir::native_types::Witness;
use acvm::{FieldElement, PartialWitnessGenerator};
use clap::ArgMatches;
use iter_extended::{try_btree_map, vecmap};
use noirc_abi::errors::AbiError;
use noirc_abi::input_parser::{Format, InputValue};
use noirc_abi::{Abi, MAIN_RETURN_NAME};
use noirc_abi::{decode_value, encode_value, AbiParameter, MAIN_RETURN_NAME};
use noirc_driver::CompiledProgram;

use super::{create_named_dir, read_inputs_from_file, write_to_file, InputMap, WitnessMap};
Expand Down Expand Up @@ -40,10 +42,6 @@ pub(crate) fn run(args: ArgMatches) -> Result<(), CliError> {
Ok(())
}

/// In Barretenberg, the proof system adds a zero witness in the first index,
/// So when we add witness values, their index start from 1.
const WITNESS_OFFSET: u32 = 1;

fn execute_with_path<P: AsRef<Path>>(
program_dir: P,
show_ssa: bool,
Expand Down Expand Up @@ -75,30 +73,12 @@ pub(crate) fn execute_program(
Ok((return_value, solved_witness))
}

pub(crate) fn extract_public_inputs(
compiled_program: &CompiledProgram,
solved_witness: &WitnessMap,
) -> Result<InputMap, AbiError> {
let encoded_public_inputs: Vec<FieldElement> = compiled_program
.circuit
.public_inputs
.0
.iter()
.map(|index| solved_witness[index])
.collect();

let public_abi = compiled_program.abi.as_ref().unwrap().clone().public_abi();

public_abi.decode(&encoded_public_inputs)
}

pub(crate) fn solve_witness(
compiled_program: &CompiledProgram,
input_map: &InputMap,
) -> Result<WitnessMap, CliError> {
let abi = compiled_program.abi.as_ref().unwrap().clone();
let mut solved_witness =
input_map_to_witness_map(abi, input_map).map_err(|error| match error {
let mut solved_witness = input_map_to_witness_map(input_map, &compiled_program.param_witnesses)
.map_err(|error| match error {
AbiError::UndefinedInput(_) => {
CliError::Generic(format!("{error} in the {PROVER_INPUT_FILE}.toml file."))
}
Expand All @@ -115,18 +95,53 @@ pub(crate) fn solve_witness(
///
/// In particular, this method shows one how to associate values in a Toml/JSON
/// file with witness indices
fn input_map_to_witness_map(abi: Abi, input_map: &InputMap) -> Result<WitnessMap, AbiError> {
// The ABI map is first encoded as a vector of field elements
let encoded_inputs = abi.encode(input_map, true)?;

Ok(encoded_inputs
.into_iter()
.enumerate()
.map(|(index, witness_value)| {
let witness = Witness::new(WITNESS_OFFSET + (index as u32));
(witness, witness_value)
fn input_map_to_witness_map(
input_map: &InputMap,
abi_witness_map: &BTreeMap<String, Vec<Witness>>,
) -> Result<WitnessMap, AbiError> {
// First encode each input separately
let encoded_input_map: BTreeMap<String, Vec<FieldElement>> =
try_btree_map(input_map, |(key, value)| {
encode_value(value.clone(), key).map(|v| (key.clone(), v))
})?;

// Write input field elements into witness indices specified in `abi_witness_map`.
let witness_map = encoded_input_map
.iter()
.flat_map(|(param_name, encoded_param_fields)| {
let param_witness_indices = &abi_witness_map[param_name];
param_witness_indices
.iter()
.zip(encoded_param_fields.iter())
.map(|(&witness, &field_element)| (witness, field_element))
})
.collect();

Ok(witness_map)
}

pub(crate) fn extract_public_inputs(
compiled_program: &CompiledProgram,
solved_witness: &WitnessMap,
) -> Result<InputMap, AbiError> {
let public_abi = compiled_program.abi.as_ref().unwrap().clone().public_abi();

let public_inputs_map = public_abi
.parameters
.iter()
.map(|AbiParameter { name, typ, .. }| {
let param_witness_values =
vecmap(compiled_program.param_witnesses[name].clone(), |witness_index| {
solved_witness[&witness_index]
});

decode_value(&mut param_witness_values.into_iter(), typ)
.map(|input_value| (name.clone(), input_value))
.unwrap()
})
.collect())
.collect();

Ok(public_inputs_map)
}

pub(crate) fn save_witness_to_dir<P: AsRef<Path>>(
Expand Down
127 changes: 63 additions & 64 deletions crates/noirc_abi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ impl Abi {
return Err(AbiError::TypeMismatch { param: param.to_owned(), value });
}

encoded_inputs.extend(Self::encode_value(value, &param.name)?);
encoded_inputs.extend(encode_value(value, &param.name)?);
}

// Check that no extra witness values have been provided.
Expand All @@ -197,27 +197,6 @@ impl Abi {
Ok(encoded_inputs)
}

fn encode_value(value: InputValue, param_name: &String) -> Result<Vec<FieldElement>, AbiError> {
let mut encoded_value = Vec::new();
match value {
InputValue::Field(elem) => encoded_value.push(elem),
InputValue::Vec(vec_elem) => encoded_value.extend(vec_elem),
InputValue::String(string) => {
let str_as_fields =
string.bytes().map(|byte| FieldElement::from_be_bytes_reduce(&[byte]));
encoded_value.extend(str_as_fields)
}
InputValue::Struct(object) => {
for (field_name, value) in object {
let new_name = format!("{param_name}.{field_name}");
encoded_value.extend(Self::encode_value(value, &new_name)?)
}
}
InputValue::Undefined => return Err(AbiError::UndefinedInput(param_name.to_string())),
}
Ok(encoded_value)
}

/// Decode a vector of `FieldElements` into the types specified in the ABI.
pub fn decode(
&self,
Expand All @@ -234,59 +213,79 @@ impl Abi {
let mut field_iterator = encoded_inputs.iter().cloned();
let mut decoded_inputs = BTreeMap::new();
for param in &self.parameters {
let decoded_value = Self::decode_value(&mut field_iterator, &param.typ)?;
let decoded_value = decode_value(&mut field_iterator, &param.typ)?;

decoded_inputs.insert(param.name.to_owned(), decoded_value);
}
Ok(decoded_inputs)
}
}

fn decode_value(
field_iterator: &mut impl Iterator<Item = FieldElement>,
value_type: &AbiType,
) -> Result<InputValue, AbiError> {
// This function assumes that `field_iterator` contains enough `FieldElement`s in order to decode a `value_type`
// `Abi.decode` enforces that the encoded inputs matches the expected length defined by the ABI so this is safe.
let value = match value_type {
AbiType::Field | AbiType::Integer { .. } | AbiType::Boolean => {
let field_element = field_iterator.next().unwrap();

InputValue::Field(field_element)
pub fn encode_value(value: InputValue, param_name: &String) -> Result<Vec<FieldElement>, AbiError> {
let mut encoded_value = Vec::new();
match value {
InputValue::Field(elem) => encoded_value.push(elem),
InputValue::Vec(vec_elem) => encoded_value.extend(vec_elem),
InputValue::String(string) => {
let str_as_fields =
string.bytes().map(|byte| FieldElement::from_be_bytes_reduce(&[byte]));
encoded_value.extend(str_as_fields)
}
InputValue::Struct(object) => {
for (field_name, value) in object {
let new_name = format!("{param_name}.{field_name}");
encoded_value.extend(encode_value(value, &new_name)?)
}
AbiType::Array { length, .. } => {
let field_elements: Vec<FieldElement> =
field_iterator.take(*length as usize).collect();
}
InputValue::Undefined => return Err(AbiError::UndefinedInput(param_name.to_string())),
}
Ok(encoded_value)
}

InputValue::Vec(field_elements)
}
AbiType::String { length } => {
let string_as_slice: Vec<u8> = field_iterator
.take(*length as usize)
.map(|e| {
let mut field_as_bytes = e.to_be_bytes();
let char_byte = field_as_bytes.pop().unwrap(); // A character in a string is represented by a u8, thus we just want the last byte of the element
assert!(field_as_bytes.into_iter().all(|b| b == 0)); // Assert that the rest of the field element's bytes are empty
char_byte
})
.collect();

let final_string = str::from_utf8(&string_as_slice).unwrap();

InputValue::String(final_string.to_owned())
}
AbiType::Struct { fields, .. } => {
let mut struct_map = BTreeMap::new();
pub fn decode_value(
field_iterator: &mut impl Iterator<Item = FieldElement>,
value_type: &AbiType,
) -> Result<InputValue, AbiError> {
// This function assumes that `field_iterator` contains enough `FieldElement`s in order to decode a `value_type`
// `Abi.decode` enforces that the encoded inputs matches the expected length defined by the ABI so this is safe.
let value = match value_type {
AbiType::Field | AbiType::Integer { .. } | AbiType::Boolean => {
let field_element = field_iterator.next().unwrap();

InputValue::Field(field_element)
}
AbiType::Array { length, .. } => {
let field_elements: Vec<FieldElement> = field_iterator.take(*length as usize).collect();

for (field_key, param_type) in fields {
let field_value = Self::decode_value(field_iterator, param_type)?;
InputValue::Vec(field_elements)
}
AbiType::String { length } => {
let string_as_slice: Vec<u8> = field_iterator
.take(*length as usize)
.map(|e| {
let mut field_as_bytes = e.to_be_bytes();
let char_byte = field_as_bytes.pop().unwrap(); // A character in a string is represented by a u8, thus we just want the last byte of the element
assert!(field_as_bytes.into_iter().all(|b| b == 0)); // Assert that the rest of the field element's bytes are empty
char_byte
})
.collect();

let final_string = str::from_utf8(&string_as_slice).unwrap();

InputValue::String(final_string.to_owned())
}
AbiType::Struct { fields, .. } => {
let mut struct_map = BTreeMap::new();

struct_map.insert(field_key.to_owned(), field_value);
}
for (field_key, param_type) in fields {
let field_value = decode_value(field_iterator, param_type)?;

InputValue::Struct(struct_map)
struct_map.insert(field_key.to_owned(), field_value);
}
};

Ok(value)
}
InputValue::Struct(struct_map)
}
};

Ok(value)
}
30 changes: 18 additions & 12 deletions crates/noirc_driver/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![forbid(unsafe_code)]
use acvm::acir::circuit::Circuit;

use acvm::acir::native_types::Witness;
use acvm::Language;
use fm::FileType;
use noirc_abi::Abi;
Expand All @@ -12,6 +13,7 @@ use noirc_frontend::hir::Context;
use noirc_frontend::monomorphization::monomorphize;
use noirc_frontend::node_interner::FuncId;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::path::{Path, PathBuf};

pub struct Driver {
Expand All @@ -22,6 +24,7 @@ pub struct Driver {
pub struct CompiledProgram {
pub circuit: Circuit,
pub abi: Option<noirc_abi::Abi>,
pub param_witnesses: BTreeMap<String, Vec<Witness>>,
}

impl Driver {
Expand Down Expand Up @@ -187,18 +190,21 @@ impl Driver {
let program = monomorphize(main_function, &self.context.def_interner);

let blackbox_supported = acvm::default_is_black_box_supported(np_language.clone());
match create_circuit(program, np_language, blackbox_supported, show_ssa) {
Ok(circuit) => Ok(CompiledProgram { circuit, abi: Some(abi) }),
Err(err) => {
// The FileId here will be the file id of the file with the main file
// Errors will be shown at the call site without a stacktrace
let file = err.location.map(|loc| loc.file);
let files = &self.context.file_manager;
let error = reporter::report(files, &err.into(), file, allow_warnings);
reporter::finish_report(error as u32)?;
Err(ReportedError)
}
}
let (circuit, param_witnesses) =
match create_circuit(program, np_language, blackbox_supported, show_ssa) {
Ok((circuit, param_witnesses)) => (circuit, param_witnesses),
Err(err) => {
// The FileId here will be the file id of the file with the main file
// Errors will be shown at the call site without a stacktrace
let file = err.location.map(|loc| loc.file);
let files = &self.context.file_manager;
let error = reporter::report(files, &err.into(), file, allow_warnings);
reporter::finish_report(error as u32)?;
return Err(ReportedError);
}
};

Ok(CompiledProgram { circuit, abi: Some(abi), param_witnesses })
}

/// Returns a list of all functions in the current crate marked with #[test]
Expand Down
4 changes: 2 additions & 2 deletions crates/noirc_evaluator/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub fn create_circuit(
np_language: Language,
is_blackbox_supported: IsBlackBoxSupported,
enable_logging: bool,
) -> Result<Circuit, RuntimeError> {
) -> Result<(Circuit, BTreeMap<String, Vec<Witness>>), RuntimeError> {
let mut evaluator = Evaluator::new();

// First evaluate the main function
Expand Down Expand Up @@ -68,7 +68,7 @@ pub fn create_circuit(
)
.map_err(|_| RuntimeErrorKind::Spanless(String::from("produced an acvm compile error")))?;

Ok(optimized_circuit)
Ok((optimized_circuit, evaluator.param_witnesses))
}

impl Evaluator {
Expand Down

0 comments on commit 73a083e

Please sign in to comment.