From b122aec30f9ebd98a42466530033d5aa600dc307 Mon Sep 17 00:00:00 2001 From: guipublic Date: Mon, 6 Feb 2023 14:19:48 +0000 Subject: [PATCH 1/3] Directive for sorting networks --- .github/workflows/rust.yml | 11 + CHANGELOG.md | 2 +- README.md | 4 +- ...ox_functions.rs => black_box_functions.rs} | 0 acir/src/circuit/directives.rs | 90 +++- acir/src/circuit/mod.rs | 12 +- acir/src/circuit/opcodes.rs | 39 +- acir/src/lib.rs | 4 +- acir/src/native_types/arithmetic.rs | 8 +- acir/src/native_types/witness.rs | 2 +- .../{serialisation.rs => serialization.rs} | 2 +- acir_field/src/generic_ark.rs | 6 +- acvm/Cargo.toml | 1 + acvm/src/compiler.rs | 40 +- acvm/src/compiler/optimiser/mod.rs | 7 - .../csat_optimizer.rs} | 84 ++-- .../general_optimizer.rs} | 6 +- acvm/src/compiler/optimizer/mod.rs | 7 + .../r1cs_optimizer.rs} | 18 +- .../range_optimizer.rs} | 16 +- acvm/src/lib.rs | 27 +- acvm/src/pwg.rs | 3 +- acvm/src/pwg/arithmetic.rs | 96 +++-- acvm/src/pwg/directives.rs | 154 ++++++- acvm/src/pwg/sorting.rs | 390 ++++++++++++++++++ cspell.json | 44 ++ stdlib/src/fallback.rs | 4 +- stdlib/src/helpers.rs | 2 +- 28 files changed, 898 insertions(+), 181 deletions(-) rename acir/src/circuit/{blackbox_functions.rs => black_box_functions.rs} (100%) rename acir/src/{serialisation.rs => serialization.rs} (95%) delete mode 100644 acvm/src/compiler/optimiser/mod.rs rename acvm/src/compiler/{optimiser/csat_optimiser.rs => optimizer/csat_optimizer.rs} (88%) rename acvm/src/compiler/{optimiser/general_optimiser.rs => optimizer/general_optimizer.rs} (87%) create mode 100644 acvm/src/compiler/optimizer/mod.rs rename acvm/src/compiler/{optimiser/r1cs_optimiser.rs => optimizer/r1cs_optimizer.rs} (70%) rename acvm/src/compiler/{optimiser/range_optimiser.rs => optimizer/range_optimizer.rs} (54%) create mode 100644 acvm/src/pwg/sorting.rs create mode 100644 cspell.json diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index d1ea0251b..935d6e471 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -22,3 +22,14 @@ jobs: run: cargo clippy --verbose - name: Run tests run: cargo test --verbose + + spellcheck: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: streetsidesoftware/cspell-action@v2 + with: + files: | + **/*.{md,rs} + incremental_files_only : true # Run this action on files which have changed in PR + strict: false # Do not fail, if a spelling mistake is found (This can be annoying for contributors) diff --git a/CHANGELOG.md b/CHANGELOG.md index e64cad60c..c6b4f98fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,7 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - XOR, Range and AND gates are no longer special case. They are now another opcode in the GadgetCall - Move fallback module to `stdlib` -- optimiser code and any other passes will live in acvm. acir is solely for defining the IR now. +- Optimizer code and any other passes will live in acvm. acir is solely for defining the IR now. - ACIR passes now live under the compiler parent module - Moved opcode module in acir crate to circuit/opcode - Rename GadgetCall to BlackBoxFuncCall diff --git a/README.md b/README.md index 7b38a27a0..49154bdc3 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # ACIR - Abstract Circuit Intermediate Representation -ACIR is an NP complete language that generalises R1CS and arithmetic circuits while not losing proving system specific optimisations through the use of black box functions. +ACIR is an NP complete language that generalizes R1CS and arithmetic circuits while not losing proving system specific optimizations through the use of black box functions. # ACVM - Abstract Circuit Virtual Machine This can be seen as the ACIR compiler. It will take an ACIR instance and convert it to the format required -by a particular proving system to create a proof. \ No newline at end of file +by a particular proving system to create a proof. diff --git a/acir/src/circuit/blackbox_functions.rs b/acir/src/circuit/black_box_functions.rs similarity index 100% rename from acir/src/circuit/blackbox_functions.rs rename to acir/src/circuit/black_box_functions.rs diff --git a/acir/src/circuit/directives.rs b/acir/src/circuit/directives.rs index c3d8e2cd0..585cd6a9b 100644 --- a/acir/src/circuit/directives.rs +++ b/acir/src/circuit/directives.rs @@ -2,7 +2,7 @@ use std::io::{Read, Write}; use crate::{ native_types::{Expression, Witness}, - serialisation::{read_n, read_u16, read_u32, write_bytes, write_u16, write_u32}, + serialization::{read_n, read_u16, read_u32, write_bytes, write_u16, write_u32}, }; use serde::{Deserialize, Serialize}; @@ -48,6 +48,16 @@ pub enum Directive { b: Vec, radix: u32, }, + + // Sort directive, using a sorting network + // This directive is used to generate the values of the control bits for the sorting network such that its ouputs are properly sorted accroding to sort_by + PermutationSort { + inputs: Vec>, // Array of tuples to sort + tuple: u32, // tuple size; if 1 then inputs is a single array [a0,a1,..], if 2 then inputs=[(a0,b0),..] is [a0,b0,a1,b1,..], etc.. + bits: Vec, // control bits of the network which permutes the inputs into its sorted version + sort_by: Vec, // specify primary index to sort by, then the secondary,... For instance, if typle is 2 and sort_by is [1,0], then a=[(a0,b0),..] is sorted by bi and then ai. + }, + Log(LogInfo), } impl Directive { @@ -58,6 +68,8 @@ impl Directive { Directive::Truncate { .. } => "truncate", Directive::OddRange { .. } => "odd_range", Directive::ToRadix { .. } => "to_radix", + Directive::PermutationSort { .. } => "permutation_sort", + Directive::Log { .. } => "log", } } fn to_u16(&self) -> u16 { @@ -67,6 +79,8 @@ impl Directive { Directive::Truncate { .. } => 2, Directive::OddRange { .. } => 3, Directive::ToRadix { .. } => 4, + Directive::Log { .. } => 5, + Directive::PermutationSort { .. } => 6, } } @@ -116,6 +130,39 @@ impl Directive { } write_u32(&mut writer, *radix)?; } + Directive::PermutationSort { + inputs: a, + tuple, + bits, + sort_by, + } => { + write_u32(&mut writer, *tuple)?; + write_u32(&mut writer, a.len() as u32)?; + for e in a { + for i in 0..*tuple { + e[i as usize].write(&mut writer)?; + } + } + write_u32(&mut writer, bits.len() as u32)?; + for b in bits { + write_u32(&mut writer, b.witness_index())?; + } + write_u32(&mut writer, sort_by.len() as u32)?; + for i in sort_by { + write_u32(&mut writer, *i)?; + } + } + Directive::Log(info) => match info { + LogInfo::FinalizedOutput(output_string) => { + write_bytes(&mut writer, output_string.as_bytes())?; + } + LogInfo::WitnessOutput(witnesses) => { + write_u32(&mut writer, witnesses.len() as u32)?; + for w in witnesses { + write_u32(&mut writer, w.witness_index())?; + } + } + }, }; Ok(()) @@ -178,14 +225,53 @@ impl Directive { Ok(Directive::ToRadix { a, b, radix }) } + 6 => { + let tuple = read_u32(&mut reader)?; + let a_len = read_u32(&mut reader)?; + let mut a = Vec::with_capacity(a_len as usize); + for _ in 0..a_len { + let mut element = Vec::new(); + for _ in 0..tuple { + element.push(Expression::read(&mut reader)?); + } + a.push(element); + } + + let bits_len = read_u32(&mut reader)?; + let mut bits = Vec::with_capacity(bits_len as usize); + for _ in 0..bits_len { + bits.push(Witness(read_u32(&mut reader)?)); + } + let sort_by_len = read_u32(&mut reader)?; + let mut sort_by = Vec::with_capacity(sort_by_len as usize); + for _ in 0..sort_by_len { + sort_by.push(read_u32(&mut reader)?); + } + Ok(Directive::PermutationSort { + inputs: a, + tuple, + bits, + sort_by, + }) + } _ => Err(std::io::ErrorKind::InvalidData.into()), } } } +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +// If values are compile time and/or known during +// evaluation, we can form an output string during ACIR generation. +// Otherwise, we must store witnesses whose values will +// be fetched during the PWG stage. +pub enum LogInfo { + FinalizedOutput(String), + WitnessOutput(Vec), +} + #[test] -fn serialisation_roundtrip() { +fn serialization_roundtrip() { fn read_write(directive: Directive) -> (Directive, Directive) { let mut bytes = Vec::new(); directive.write(&mut bytes).unwrap(); diff --git a/acir/src/circuit/mod.rs b/acir/src/circuit/mod.rs index 5d7c8260e..a6b0fdb07 100644 --- a/acir/src/circuit/mod.rs +++ b/acir/src/circuit/mod.rs @@ -1,10 +1,10 @@ -pub mod blackbox_functions; +pub mod black_box_functions; pub mod directives; pub mod opcodes; pub use opcodes::Opcode; use crate::native_types::Witness; -use crate::serialisation::{read_u32, write_u32}; +use crate::serialization::{read_u32, write_u32}; use rmp_serde; use serde::{Deserialize, Serialize}; @@ -27,7 +27,7 @@ impl Circuit { } #[deprecated( - note = "we want to use a serialisation strategy that is easy to implement in many languages (without ffi). use `read` instead" + note = "we want to use a serialization strategy that is easy to implement in many languages (without ffi). use `read` instead" )] pub fn from_bytes(bytes: &[u8]) -> Circuit { let mut deflater = DeflateDecoder::new(bytes); @@ -37,7 +37,7 @@ impl Circuit { } #[deprecated( - note = "we want to use a serialisation strategy that is easy to implement in many languages (without ffi).use `write` instead" + note = "we want to use a serialization strategy that is easy to implement in many languages (without ffi).use `write` instead" )] pub fn to_bytes(&self) -> Vec { let buf = rmp_serde::to_vec(&self).unwrap(); @@ -69,7 +69,7 @@ impl Circuit { // TODO (Note): we could use semver versioning from the Cargo.toml // here and then reject anything that has a major bump // - // We may also not want to do that if we do not want to couple serialisation + // We may also not want to do that if we do not want to couple serialization // with other breaking changes if version_number != VERSION_NUMBER { return Err(std::io::ErrorKind::InvalidData.into()); @@ -178,7 +178,7 @@ mod test { } #[test] - fn serialisation_roundtrip() { + fn serialization_roundtrip() { let circuit = Circuit { current_witness_index: 5, opcodes: vec![and_opcode(), range_opcode()], diff --git a/acir/src/circuit/opcodes.rs b/acir/src/circuit/opcodes.rs index 37017adde..8d6794387 100644 --- a/acir/src/circuit/opcodes.rs +++ b/acir/src/circuit/opcodes.rs @@ -1,8 +1,8 @@ use std::io::{Read, Write}; -use super::directives::Directive; +use super::directives::{Directive, LogInfo}; use crate::native_types::{Expression, Witness}; -use crate::serialisation::{read_n, read_u16, read_u32, write_bytes, write_u16, write_u32}; +use crate::serialization::{read_n, read_u16, read_u32, write_bytes, write_u16, write_u32}; use crate::BlackBoxFunc; use serde::{Deserialize, Serialize}; @@ -25,7 +25,7 @@ impl Opcode { } // We have three types of opcodes allowed in the IR // Expression, BlackBoxFuncCall and Directives - // When we serialise these opcodes, we use the index + // When we serialize these opcodes, we use the index // to uniquely identify which category of opcode we are dealing with. pub(crate) fn to_index(&self) -> u8 { match self { @@ -164,6 +164,33 @@ impl std::fmt::Display for Opcode { b.last().unwrap().witness_index(), ) } + Opcode::Directive(Directive::PermutationSort { + inputs: a, + tuple, + bits, + sort_by, + }) => { + write!(f, "DIR::PERMUTATIONSORT ")?; + write!( + f, + "(permutation size: {} {}-tuples, sort_by: {:#?}, bits: [_{}..._{}]))", + a.len(), + tuple, + sort_by, + // (Note): the bits do not have contiguous index but there are too many for display + bits.first().unwrap().witness_index(), + bits.last().unwrap().witness_index(), + ) + } + Opcode::Directive(Directive::Log(info)) => match info { + LogInfo::FinalizedOutput(output_string) => write!(f, "Log: {output_string}"), + LogInfo::WitnessOutput(witnesses) => write!( + f, + "Log: _{}..._{}", + witnesses.first().unwrap().witness_index(), + witnesses.last().unwrap().witness_index() + ), + }, } } } @@ -326,7 +353,7 @@ impl std::fmt::Debug for BlackBoxFuncCall { } #[test] -fn serialisation_roundtrip() { +fn serialization_roundtrip() { fn read_write(opcode: Opcode) -> (Opcode, Opcode) { let mut bytes = Vec::new(); opcode.write(&mut bytes).unwrap(); @@ -336,7 +363,7 @@ fn serialisation_roundtrip() { let opcode_arith = Opcode::Arithmetic(Expression::default()); - let opcode_blackbox_func = Opcode::BlackBoxFuncCall(BlackBoxFuncCall { + let opcode_black_box_func = Opcode::BlackBoxFuncCall(BlackBoxFuncCall { name: BlackBoxFunc::AES, inputs: vec![ FunctionInput { @@ -356,7 +383,7 @@ fn serialisation_roundtrip() { result: Witness(56789u32), }); - let opcodes = vec![opcode_arith, opcode_blackbox_func, opcode_directive]; + let opcodes = vec![opcode_arith, opcode_black_box_func, opcode_directive]; for opcode in opcodes { let (op, got_op) = read_write(opcode); diff --git a/acir/src/lib.rs b/acir/src/lib.rs index ff1f4a951..303341f5e 100644 --- a/acir/src/lib.rs +++ b/acir/src/lib.rs @@ -2,7 +2,7 @@ pub mod circuit; pub mod native_types; -mod serialisation; +mod serialization; pub use acir_field::FieldElement; -pub use circuit::blackbox_functions::BlackBoxFunc; +pub use circuit::black_box_functions::BlackBoxFunc; diff --git a/acir/src/native_types/arithmetic.rs b/acir/src/native_types/arithmetic.rs index 56621919a..342a70819 100644 --- a/acir/src/native_types/arithmetic.rs +++ b/acir/src/native_types/arithmetic.rs @@ -1,5 +1,5 @@ use crate::native_types::{Linear, Witness}; -use crate::serialisation::{read_field_element, read_u32, write_bytes, write_u32}; +use crate::serialization::{read_field_element, read_u32, write_bytes, write_u32}; use acir_field::FieldElement; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; @@ -17,7 +17,7 @@ use super::witness::UnknownWitness; // XXX: If we allow the degree of the quotient polynomial to be arbitrary, then we will need a vector of wire values #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct Expression { - // To avoid having to create intermediate variables pre-optimisation + // To avoid having to create intermediate variables pre-optimization // We collect all of the multiplication terms in the arithmetic gate // A multiplication term if of the form q_M * wL * wR // Hence this vector represents the following sum: q_M1 * wL1 * wR1 + q_M2 * wL2 * wR2 + .. + @@ -448,7 +448,7 @@ impl Expression { // A polynomial whose mul terms are non zero which do not match up with two terms in the fan-in cannot fit into one gate // An example of this is: Axy + Bx + Cy + ... // Notice how the bivariate monomial xy has two univariate monomials with their respective coefficients - // XXX: note that if x or y is zero, then we could apply a further optimisation, but this would be done in another algorithm. + // XXX: note that if x or y is zero, then we could apply a further optimization, but this would be done in another algorithm. // It would be the same as when we have zero coefficients - Can only work if wire is constrained to be zero publicly let mul_term = &self.mul_terms[0]; @@ -478,7 +478,7 @@ impl Expression { } #[test] -fn serialisation_roundtrip() { +fn serialization_roundtrip() { // Empty expression // let expr = Expression::default(); diff --git a/acir/src/native_types/witness.rs b/acir/src/native_types/witness.rs index 8bcc62a72..d7753ec7c 100644 --- a/acir/src/native_types/witness.rs +++ b/acir/src/native_types/witness.rs @@ -55,7 +55,7 @@ impl Witness { // We use this, so that they are pushed to the beginning of the array // // When they are pushed to the beginning of the array, they are less likely to be used in an intermediate gate -// by the optimiser, which would mean two unknowns in an equation. +// by the optimizer, which would mean two unknowns in an equation. // See Issue #20 // TODO: can we find a better solution to this? pub struct UnknownWitness(pub u32); diff --git a/acir/src/serialisation.rs b/acir/src/serialization.rs similarity index 95% rename from acir/src/serialisation.rs rename to acir/src/serialization.rs index 838e01667..1fe719936 100644 --- a/acir/src/serialisation.rs +++ b/acir/src/serialization.rs @@ -43,7 +43,7 @@ pub fn read_field_element( let bytes = read_n::(&mut r)?; - // TODO: We should not reduce here, we want the serialisation to be + // TODO: We should not reduce here, we want the serialization to be // TODO canonical let field_element = FieldElement::from_be_bytes_reduce(&bytes); diff --git a/acir_field/src/generic_ark.rs b/acir_field/src/generic_ark.rs index fb3722ea7..76518d3ca 100644 --- a/acir_field/src/generic_ark.rs +++ b/acir_field/src/generic_ark.rs @@ -51,7 +51,7 @@ impl std::fmt::Display for FieldElement { // we usually have numbers in the form 2^t * q + r // We focus on 2^64, 2^32, 2^16, 2^8, 2^4 because // they are common. We could extend this to a more - // general factorisation strategy, but we pay in terms of CPU time + // general factorization strategy, but we pay in terms of CPU time let mul_sign = "×"; for power in [64, 32, 16, 8, 4] { let power_of_two = BigUint::from(2_u128).pow(power); @@ -226,7 +226,7 @@ impl FieldElement { } /// Computes the inverse or returns zero if the inverse does not exist - /// Before using this FieldElement, please ensure that this behaviour is necessary + /// Before using this FieldElement, please ensure that this behavior is necessary pub fn inverse(&self) -> FieldElement { let inv = self.0.inverse().unwrap_or_else(F::zero); FieldElement(inv) @@ -297,7 +297,7 @@ impl FieldElement { let num_elements = num_bytes / 8; let mut bytes = self.to_be_bytes(); - bytes.reverse(); // put it in big endian format. XXX(next refactor): we should be explicit about endianess. + bytes.reverse(); // put it in big endian format. XXX(next refactor): we should be explicit about endianness. bytes[0..num_elements].to_vec() } diff --git a/acvm/Cargo.toml b/acvm/Cargo.toml index cbd53d0fd..f6715e274 100644 --- a/acvm/Cargo.toml +++ b/acvm/Cargo.toml @@ -33,3 +33,4 @@ bls12_381 = ["acir_field/bls12_381"] [dev-dependencies] tempfile = "3.2.0" +rand="0.8.5" \ No newline at end of file diff --git a/acvm/src/compiler.rs b/acvm/src/compiler.rs index 15df5f1ed..412921c77 100644 --- a/acvm/src/compiler.rs +++ b/acvm/src/compiler.rs @@ -1,6 +1,6 @@ // The various passes that we can use over ACIR pub mod fallback; -pub mod optimiser; +pub mod optimizer; use crate::Language; use acir::{ @@ -9,10 +9,10 @@ use acir::{ BlackBoxFunc, }; use indexmap::IndexMap; -use optimiser::{CSatOptimiser, GeneralOptimiser}; +use optimizer::{CSatOptimizer, GeneralOptimizer}; use thiserror::Error; -use self::{fallback::IsBlackBoxSupported, optimiser::R1CSOptimiser}; +use self::{fallback::IsBlackBoxSupported, optimizer::R1CSOptimizer}; #[derive(PartialEq, Eq, Debug, Error)] pub enum CompileError { @@ -23,30 +23,30 @@ pub enum CompileError { pub fn compile( acir: Circuit, np_language: Language, - is_blackbox_supported: IsBlackBoxSupported, + is_black_box_supported: IsBlackBoxSupported, ) -> Result { - // Instantiate the optimiser. - // Currently the optimiser and reducer are one in the same + // Instantiate the optimizer. + // Currently the optimizer and reducer are one in the same // for CSAT // Fallback pass - let fallback = fallback::fallback(acir, is_blackbox_supported)?; + let fallback = fallback::fallback(acir, is_black_box_supported)?; - let optimiser = match &np_language { + let optimizer = match &np_language { crate::Language::R1CS => { - let optimiser = R1CSOptimiser::new(fallback); - return Ok(optimiser.optimise()); + let optimizer = R1CSOptimizer::new(fallback); + return Ok(optimizer.optimize()); } - crate::Language::PLONKCSat { width } => CSatOptimiser::new(*width), + crate::Language::PLONKCSat { width } => CSatOptimizer::new(*width), }; - // TODO: the code below is only for CSAT optimiser + // TODO: the code below is only for CSAT optimizer // TODO it may be possible to refactor it in a way that we do not need to return early from the r1cs - // TODO or at the very least, we could put all of it inside of CSATOptimiser pass + // TODO or at the very least, we could put all of it inside of CSatOptimizer pass - // Optimise the arithmetic gates by reducing them into the correct width and + // Optimize the arithmetic gates by reducing them into the correct width and // creating intermediate variables when necessary - let mut optimised_gates = Vec::new(); + let mut optimized_gates = Vec::new(); let mut next_witness_index = fallback.current_witness_index + 1; for opcode in fallback.opcodes { @@ -55,7 +55,7 @@ pub fn compile( let mut intermediate_variables: IndexMap = IndexMap::new(); let arith_expr = - optimiser.optimise(arith_expr, &mut intermediate_variables, next_witness_index); + optimizer.optimize(arith_expr, &mut intermediate_variables, next_witness_index); // Update next_witness counter next_witness_index += intermediate_variables.len() as u32; @@ -67,10 +67,10 @@ pub fn compile( new_gates.push(arith_expr); new_gates.sort(); for gate in new_gates { - optimised_gates.push(Opcode::Arithmetic(gate)); + optimized_gates.push(Opcode::Arithmetic(gate)); } } - other_gate => optimised_gates.push(other_gate), + other_gate => optimized_gates.push(other_gate), } } @@ -78,7 +78,7 @@ pub fn compile( Ok(Circuit { current_witness_index, - opcodes: optimised_gates, - public_inputs: fallback.public_inputs, // The optimiser does not add public inputs + opcodes: optimized_gates, + public_inputs: fallback.public_inputs, // The optimizer does not add public inputs }) } diff --git a/acvm/src/compiler/optimiser/mod.rs b/acvm/src/compiler/optimiser/mod.rs deleted file mode 100644 index 02b03c0bc..000000000 --- a/acvm/src/compiler/optimiser/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod csat_optimiser; -mod general_optimiser; -mod r1cs_optimiser; - -pub use csat_optimiser::Optimiser as CSatOptimiser; -pub use general_optimiser::GeneralOpt as GeneralOptimiser; -pub use r1cs_optimiser::R1CSOptimiser; diff --git a/acvm/src/compiler/optimiser/csat_optimiser.rs b/acvm/src/compiler/optimizer/csat_optimizer.rs similarity index 88% rename from acvm/src/compiler/optimiser/csat_optimiser.rs rename to acvm/src/compiler/optimizer/csat_optimizer.rs index 94126b89f..454bcde42 100644 --- a/acvm/src/compiler/optimiser/csat_optimiser.rs +++ b/acvm/src/compiler/optimizer/csat_optimizer.rs @@ -6,53 +6,53 @@ use acir::{ }; use indexmap::IndexMap; -use super::general_optimiser::GeneralOpt; -// Optimiser struct with all of the related optimisations to the arithmetic gate +use super::general_optimizer::GeneralOpt; +// Optimizer struct with all of the related optimizations to the arithmetic gate -// Is this more of a Reducer than an optimiser? +// Is this more of a Reducer than an optimizer? // Should we give it all of the gates? -// Have a single optimiser that you instantiate with a width, then pass many gates through -pub struct Optimiser { +// Have a single optimizer that you instantiate with a width, then pass many gates through +pub struct Optimizer { width: usize, } -impl Optimiser { - // Configure the width for the optimiser - pub fn new(width: usize) -> Optimiser { +impl Optimizer { + // Configure the width for the Optimizer + pub fn new(width: usize) -> Optimizer { assert!(width > 2); - Optimiser { width } + Optimizer { width } } - // Still missing dead witness optimisation. + // Still missing dead witness optimization. // To do this, we will need the whole set of arithmetic gates - // I think it can also be done before the local optimisation seen here, as dead variables will come from the user - pub fn optimise( + // I think it can also be done before the local optimization seen here, as dead variables will come from the user + pub fn optimize( &self, gate: Expression, intermediate_variables: &mut IndexMap, num_witness: u32, ) -> Expression { - let gate = GeneralOpt::optimise(gate); + let gate = GeneralOpt::optimize(gate); // Here we create intermediate variables and constrain them to be equal to any subset of the polynomial that can be represented as a full gate - let gate = self.full_gate_scan_optimisation(gate, intermediate_variables, num_witness); - // The last optimisation to do is to create intermediate variables in order to flatten the fan-in and the amount of mul terms + let gate = self.full_gate_scan_optimization(gate, intermediate_variables, num_witness); + // The last optimization to do is to create intermediate variables in order to flatten the fan-in and the amount of mul terms // If a gate has more than one mul term. We may need an intermediate variable for each one. Since not every variable will need to link to // the mul term, we could possibly do it that way. - // We wil call this a partial gate scan optimisation which will result in the gates being able to fit into the correct width + // We wil call this a partial gate scan optimization which will result in the gates being able to fit into the correct width let mut gate = - self.partial_gate_scan_optimisation(gate, intermediate_variables, num_witness); + self.partial_gate_scan_optimization(gate, intermediate_variables, num_witness); gate.sort(); gate } - // This optimisation will search for combinations of terms which can be represented in a single arithmetic gate + // This optimization will search for combinations of terms which can be represented in a single arithmetic gate // Case 1 : qM * wL * wR + qL * wL + qR * wR + qO * wO + qC - // This polynomial does not require any further optimisations, it can be safely represented in one gate + // This polynomial does not require any further optimizations, it can be safely represented in one gate // ie a polynomial with 1 mul(bi-variate) term and 3 (univariate) terms where 2 of those terms match the bivariate term // wL and wR, we can represent it in one gate - // GENERALISED for WIDTH: instead of the number 3, we use `WIDTH` + // GENERALIZED for WIDTH: instead of the number 3, we use `WIDTH` // // // Case 2: qM * wL * wR + qL * wL + qR * wR + qO * wO + qC + qM2 * wL2 * wR2 + qL * wL2 + qR * wR2 + qO * wO2 + qC2 @@ -70,7 +70,7 @@ impl Optimiser { // The polynomial now looks like so t + t2 // We can no longer extract another full gate, hence the algorithm terminates. Creating two intermediate variables t and t2. // This stage of preprocessing does not guarantee that all polynomials can fit into a gate. It only guarantees that all full gates have been extracted from each polynomial - fn full_gate_scan_optimisation( + fn full_gate_scan_optimization( &self, mut gate: Expression, intermediate_variables: &mut IndexMap, @@ -79,11 +79,11 @@ impl Optimiser { // We pass around this intermediate variable IndexMap, so that we do not create intermediate variables that we have created before // One instance where this might happen is t1 = wL * wR and t2 = wR * wL - // First check that this is not a simple gate which does not need optimisation + // First check that this is not a simple gate which does not need optimization // - // If the gate only has one mul term, then this algorithm cannot optimise it any further + // If the gate only has one mul term, then this algorithm cannot optimize it any further // Either it can be represented in a single arithmetic equation or it's fan-in is too large and we need intermediate variables for those - // large-fan-in optimisation is not this algorithms purpose. + // large-fan-in optimization is not this algorithms purpose. // If the gate has 0 mul terms, then it is an add gate and similarly it can either fit into a single arithmetic gate or it has a large fan-in if gate.mul_terms.len() <= 1 { return gate; @@ -103,7 +103,7 @@ impl Optimiser { // Check if this pair is present in the simplified fan-in // We are assuming that the fan-in/fan-out has been simplified. - // Note this function is not public, and can only be called within the optimise method, so this guarantee will always hold + // Note this function is not public, and can only be called within the optimize method, so this guarantee will always hold let index_wl = gate .linear_combinations .iter() @@ -162,15 +162,15 @@ impl Optimiser { // Add this element into the new gate intermediate_gate.linear_combinations.push(wire_term); } else { - // Nomore elements left in the old gate, we could stop the whole function - // We could alternative let it keep going, as it will never reach this branch again since there are nomore elements left - // XXX: Future optimisation - // nomoreleft = true + // No more elements left in the old gate, we could stop the whole function + // We could alternative let it keep going, as it will never reach this branch again since there are no more elements left + // XXX: Future optimization + // no_more_left = true } } // Constraint this intermediate_gate to be equal to the temp variable by adding it into the IndexMap // We need a unique name for our intermediate variable - // XXX: Another optimisation, which could be applied in another algorithm + // XXX: Another optimization, which could be applied in another algorithm // If two gates have a large fan-in/out and they share a few common terms, then we should create intermediate variables for them // Do some sort of subset matching algorithm for this on the terms of the polynomial let inter_var = Witness(intermediate_variables.len() as u32 + num_witness); @@ -202,16 +202,16 @@ impl Optimiser { new_gate } - // A partial gate scan optimisation aim to create intermediate variables in order to compress the polynomial + // A partial gate scan optimization aim to create intermediate variables in order to compress the polynomial // So that it fits within the given width - // Note that this gate follows the full gate scan optimisation. + // Note that this gate follows the full gate scan optimization. // We define the partial width as equal to the full width - 2. // This is because two of our variables cannot be used as they are linked to the multiplication terms // Example: qM1 * wL1 * wR2 + qL1 * wL3 + qR1 * wR4+ qR2 * wR5 + qO1 * wO5 + qC // One thing to note is that the multiplication wires do not match any of the fan-in/out wires. This is guaranteed as we have - // just completed the full gate optimisation algorithm. + // just completed the full gate optimization algorithm. // - //Actually we can optimise in two ways here: We can create an intermediate variable which is equal to the fan-in terms + //Actually we can optimize in two ways here: We can create an intermediate variable which is equal to the fan-in terms // t = qL1 * wL3 + qR1 * wR4 -> width = 3 // This `t` value can only use width - 1 terms // The gate now looks like: qM1 * wL1 * wR2 + t + qR2 * wR5+ qO1 * wO5 + qC @@ -234,12 +234,12 @@ impl Optimiser { // The gate now looks like: t2 + qR1 * wR4+ qR2 * wR5 + qO1 * wO5 + qC // t3 = t2 + qR1 * wR4 // The gate now looks like: t3 + qR2 * wR5 + qO1 * wO5 + qC - // This took the same amount of gates, but which one is better when the width increases? Compute this and maybe do both optimisations + // This took the same amount of gates, but which one is better when the width increases? Compute this and maybe do both optimizations // naming : partial_gate_mul_first_opt and partial_gate_fan_first_opt // Also remember that since we did full gate scan, there is no way we can have a non-zero mul term along with the wL and wR terms being non-zero // // Cases, a lot of mul terms, a lot of fan-in terms, 50/50 - fn partial_gate_scan_optimisation( + fn partial_gate_scan_optimization( &self, mut gate: Expression, intermediate_variables: &mut IndexMap, @@ -248,7 +248,7 @@ impl Optimiser { // We will go for the easiest route, which is to convert all multiplications into additions using intermediate variables // Then use intermediate variables again to squash the fan-in, so that it can fit into the appropriate width - // First check if this polynomial actually needs a partial gate optimisation + // First check if this polynomial actually needs a partial gate optimization // There is the chance that it fits perfectly within the arithmetic gate if gate.fits_in_one_identity(self.width) { return gate; @@ -321,7 +321,7 @@ impl Optimiser { // keep consistency with the original equation. gate.linear_combinations.extend(added); - self.partial_gate_scan_optimisation(gate, intermediate_variables, num_witness) + self.partial_gate_scan_optimization(gate, intermediate_variables, num_witness) } } @@ -348,8 +348,8 @@ fn simple_reduction_smoke_test() { let num_witness = 4; - let optimiser = Optimiser::new(3); - let got_optimised_gate_a = optimiser.optimise(gate_a, &mut intermediate_variables, num_witness); + let optimizer = Optimizer::new(3); + let got_optimized_gate_a = optimizer.optimize(gate_a, &mut intermediate_variables, num_witness); // a = b + c + d => a - b - c - d = 0 // For width3, the result becomes: @@ -358,7 +358,7 @@ fn simple_reduction_smoke_test() { // // a - b + e = 0 let e = Witness(4); - let expected_optimised_gate_a = Expression { + let expected_optimized_gate_a = Expression { mul_terms: vec![], linear_combinations: vec![ (FieldElement::one(), a), @@ -367,7 +367,7 @@ fn simple_reduction_smoke_test() { ], q_c: FieldElement::zero(), }; - assert_eq!(expected_optimised_gate_a, got_optimised_gate_a); + assert_eq!(expected_optimized_gate_a, got_optimized_gate_a); assert_eq!(intermediate_variables.len(), 1); diff --git a/acvm/src/compiler/optimiser/general_optimiser.rs b/acvm/src/compiler/optimizer/general_optimizer.rs similarity index 87% rename from acvm/src/compiler/optimiser/general_optimiser.rs rename to acvm/src/compiler/optimizer/general_optimizer.rs index 65e95e13b..6f81c896e 100644 --- a/acvm/src/compiler/optimiser/general_optimiser.rs +++ b/acvm/src/compiler/optimizer/general_optimizer.rs @@ -6,8 +6,8 @@ use indexmap::IndexMap; pub struct GeneralOpt; impl GeneralOpt { - pub fn optimise(gate: Expression) -> Expression { - // XXX: Perhaps this optimisation can be done on the fly + pub fn optimize(gate: Expression) -> Expression { + // XXX: Perhaps this optimization can be done on the fly let gate = remove_zero_coefficients(gate); simplify_mul_terms(gate) } @@ -27,7 +27,7 @@ pub fn remove_zero_coefficients(mut gate: Expression) -> Expression { pub fn simplify_mul_terms(mut gate: Expression) -> Expression { let mut hash_map: IndexMap<(Witness, Witness), FieldElement> = IndexMap::new(); - // Canonicalise the ordering of the multiplication, lets just order by variable name + // Canonicalize the ordering of the multiplication, lets just order by variable name for (scale, w_l, w_r) in gate.mul_terms.clone().into_iter() { let mut pair = vec![w_l, w_r]; // Sort using rust sort algorithm diff --git a/acvm/src/compiler/optimizer/mod.rs b/acvm/src/compiler/optimizer/mod.rs new file mode 100644 index 000000000..e84d7cb79 --- /dev/null +++ b/acvm/src/compiler/optimizer/mod.rs @@ -0,0 +1,7 @@ +mod csat_optimizer; +mod general_optimizer; +mod r1cs_optimizer; + +pub use csat_optimizer::Optimizer as CSatOptimizer; +pub use general_optimizer::GeneralOpt as GeneralOptimizer; +pub use r1cs_optimizer::R1CSOptimizer; diff --git a/acvm/src/compiler/optimiser/r1cs_optimiser.rs b/acvm/src/compiler/optimizer/r1cs_optimizer.rs similarity index 70% rename from acvm/src/compiler/optimiser/r1cs_optimiser.rs rename to acvm/src/compiler/optimizer/r1cs_optimizer.rs index f76562164..01ecaffa1 100644 --- a/acvm/src/compiler/optimiser/r1cs_optimiser.rs +++ b/acvm/src/compiler/optimizer/r1cs_optimizer.rs @@ -1,34 +1,34 @@ -use crate::compiler::GeneralOptimiser; +use crate::compiler::GeneralOptimizer; use acir::circuit::{Circuit, Opcode}; -pub struct R1CSOptimiser { +pub struct R1CSOptimizer { acir: Circuit, } -impl R1CSOptimiser { +impl R1CSOptimizer { pub fn new(acir: Circuit) -> Self { Self { acir } } - // R1CS optimisations uses the general optimiser. + // R1CS optimizations uses the general optimizer. // TODO: We could possibly make sure that all polynomials are at most degree-2 - pub fn optimise(self) -> Circuit { - let optimised_arith_gates: Vec<_> = self + pub fn optimize(self) -> Circuit { + let optimized_arith_gates: Vec<_> = self .acir .opcodes .into_iter() .map(|gate| match gate { - Opcode::Arithmetic(arith) => Opcode::Arithmetic(GeneralOptimiser::optimise(arith)), + Opcode::Arithmetic(arith) => Opcode::Arithmetic(GeneralOptimizer::optimize(arith)), other_gates => other_gates, }) .collect(); Circuit { - // The general optimiser may remove enough gates that a witness is no longer used + // The general optimizer may remove enough gates that a witness is no longer used // however, we cannot decrement the number of witnesses, as that // would require a linear scan over all gates in order to decrement all witness indices // above the witness which was removed current_witness_index: self.acir.current_witness_index, - opcodes: optimised_arith_gates, + opcodes: optimized_arith_gates, public_inputs: self.acir.public_inputs, } } diff --git a/acvm/src/compiler/optimiser/range_optimiser.rs b/acvm/src/compiler/optimizer/range_optimizer.rs similarity index 54% rename from acvm/src/compiler/optimiser/range_optimiser.rs rename to acvm/src/compiler/optimizer/range_optimizer.rs index 6d733cdb6..9bfa86d08 100644 --- a/acvm/src/compiler/optimiser/range_optimiser.rs +++ b/acvm/src/compiler/optimizer/range_optimizer.rs @@ -1,19 +1,19 @@ // XXX: We could alleviate a runtime check from noir // By casting directly // Example: -// priv z1 = x as u32 -// priv z2 = x as u16 +// priv z1 = x as u32 +// priv z2 = x as u16 // // The IR would see both casts and replace it with -// -// +// +// // priv z1 = x as u16; // priv z2 = x as u16; // -// -// Then maybe another optimisation could be done so that it transforms into +// +// Then maybe another optimization could be done so that it transforms into // // priv z1 = x as u16 // priv z2 = z1 -// This is what I would call a general optimisation, so it could live inside of the IR module -// A more specific optimisation would be to have z2 = z1 not use a gate (copy_from_to), this is more specific to plonk-aztec and would not live in this module \ No newline at end of file +// This is what I would call a general optimization, so it could live inside of the IR module +// A more specific optimization would be to have z2 = z1 not use a gate (copy_from_to), this is more specific to plonk-aztec and would not live in this module diff --git a/acvm/src/lib.rs b/acvm/src/lib.rs index c44fd9632..312fcc665 100644 --- a/acvm/src/lib.rs +++ b/acvm/src/lib.rs @@ -67,7 +67,7 @@ pub trait PartialWitnessGenerator { let resolution = match &opcode { Opcode::Arithmetic(expr) => ArithmeticSolver::solve(initial_witness, expr), Opcode::BlackBoxFuncCall(bb_func) => { - Self::solve_blackbox_function_call(initial_witness, bb_func) + Self::solve_black_box_function_call(initial_witness, bb_func) } Opcode::Directive(directive) => Self::solve_directives(initial_witness, directive), }; @@ -88,7 +88,7 @@ pub trait PartialWitnessGenerator { self.solve(initial_witness, unsolved_opcodes) } - fn solve_blackbox_function_call( + fn solve_black_box_function_call( initial_witness: &mut BTreeMap, func_call: &BlackBoxFuncCall, ) -> Result<(), OpcodeResolutionError>; @@ -139,11 +139,11 @@ pub trait ProofSystemCompiler { /// as this in most cases will be inefficient. For this reason, we want to throw a hard error /// if the language and proof system does not line up. fn np_language(&self) -> Language; - // Returns true if the backend supports the selected blackbox function - fn blackbox_function_supported(&self, opcode: &BlackBoxFunc) -> bool; + // Returns true if the backend supports the selected black box function + fn black_box_function_supported(&self, opcode: &BlackBoxFunc) -> bool; /// Creates a Proof given the circuit description and the witness values. - /// It is important to note that the intermediate witnesses for blackbox functions will not generated + /// It is important to note that the intermediate witnesses for black box functions will not generated /// This is the responsibility of the proof system. /// /// See `SmartContract` regarding the removal of `num_witnesses` and `num_public_inputs` @@ -180,7 +180,7 @@ pub enum Language { pub fn hash_constraint_system(cs: &Circuit) -> [u8; 32] { let mut bytes = Vec::new(); - cs.write(&mut bytes).expect("could not serialise circuit"); + cs.write(&mut bytes).expect("could not serialize circuit"); use sha2::{digest::FixedOutput, Digest, Sha256}; let mut hasher = Sha256::new(); @@ -190,15 +190,15 @@ pub fn hash_constraint_system(cs: &Circuit) -> [u8; 32] { } #[deprecated( - note = "For backwards compatibility, this method allows you to derive _sensible_ defaults for blackbox function support based on the np language. \n Backends should simply specify what they support." + note = "For backwards compatibility, this method allows you to derive _sensible_ defaults for black box function support based on the np language. \n Backends should simply specify what they support." )] // This is set to match the previous functionality that we had -// Where we could deduce what blackbox functions were supported +// Where we could deduce what black box functions were supported // by knowing the np complete language -pub fn default_is_blackbox_supported( +pub fn default_is_black_box_supported( language: Language, ) -> compiler::fallback::IsBlackBoxSupported { - // R1CS does not support any of the blackbox functions by default. + // R1CS does not support any of the black box functions by default. // The compiler will replace those that it can -- ie range, xor, and fn r1cs_is_supported(opcode: &BlackBoxFunc) -> bool { match opcode { @@ -206,15 +206,12 @@ pub fn default_is_blackbox_supported( } } - // PLONK supports most of the blackbox functions by default + // PLONK supports most of the black box functions by default // The ones which are not supported, the acvm compiler will // attempt to transform into supported gates. If these are also not available // then a compiler error will be emitted. fn plonk_is_supported(opcode: &BlackBoxFunc) -> bool { - match opcode { - BlackBoxFunc::AES => false, - _ => true, - } + !matches!(opcode, BlackBoxFunc::AES) } match language { diff --git a/acvm/src/pwg.rs b/acvm/src/pwg.rs index 1be89ef19..a650e7b89 100644 --- a/acvm/src/pwg.rs +++ b/acvm/src/pwg.rs @@ -11,11 +11,12 @@ use std::collections::BTreeMap; pub mod arithmetic; // Directives pub mod directives; -// blackbox functions +// black box functions pub mod hash; pub mod logic; pub mod range; pub mod signature; +pub mod sorting; // Returns the concrete value for a particular witness // If the witness has no assignment, then diff --git a/acvm/src/pwg/arithmetic.rs b/acvm/src/pwg/arithmetic.rs index 7ac522413..2d611db02 100644 --- a/acvm/src/pwg/arithmetic.rs +++ b/acvm/src/pwg/arithmetic.rs @@ -29,6 +29,7 @@ impl ArithmeticSolver { initial_witness: &mut BTreeMap, gate: &Expression, ) -> Result<(), OpcodeResolutionError> { + let gate = &ArithmeticSolver::evaluate(gate, initial_witness); // Evaluate multiplication term let mul_result = ArithmeticSolver::solve_mul_term(gate, initial_witness); // Evaluate the fan-in terms @@ -124,29 +125,44 @@ impl ArithmeticSolver { witness_assignments: &BTreeMap, ) -> MulTerm { // First note that the mul term can only contain one/zero term - // We are assuming it has been optimised. + // We are assuming it has been optimized. match arith_gate.mul_terms.len() { 0 => MulTerm::Solved(FieldElement::zero()), - 1 => { - let q_m = &arith_gate.mul_terms[0].0; - let w_l = &arith_gate.mul_terms[0].1; - let w_r = &arith_gate.mul_terms[0].2; - - // Check if these values are in the witness assignments - let w_l_value = witness_assignments.get(w_l); - let w_r_value = witness_assignments.get(w_r); - - match (w_l_value, w_r_value) { - (None, None) => MulTerm::TooManyUnknowns, - (Some(w_l), Some(w_r)) => MulTerm::Solved(*q_m * *w_l * *w_r), - (None, Some(w_r)) => MulTerm::OneUnknown(*q_m * *w_r, *w_l), - (Some(w_l), None) => MulTerm::OneUnknown(*q_m * *w_l, *w_r), - } - } + 1 => ArithmeticSolver::solve_mul_term_helper( + &arith_gate.mul_terms[0], + witness_assignments, + ), _ => panic!("Mul term in the arithmetic gate must contain either zero or one term"), } } + fn solve_mul_term_helper( + term: &(FieldElement, Witness, Witness), + witness_assignments: &BTreeMap, + ) -> MulTerm { + let (q_m, w_l, w_r) = term; + // Check if these values are in the witness assignments + let w_l_value = witness_assignments.get(w_l); + let w_r_value = witness_assignments.get(w_r); + + match (w_l_value, w_r_value) { + (None, None) => MulTerm::TooManyUnknowns, + (Some(w_l), Some(w_r)) => MulTerm::Solved(*q_m * *w_l * *w_r), + (None, Some(w_r)) => MulTerm::OneUnknown(*q_m * *w_r, *w_l), + (Some(w_l), None) => MulTerm::OneUnknown(*q_m * *w_l, *w_r), + } + } + + fn solve_fan_in_term_helper( + term: &(FieldElement, Witness), + witness_assignments: &BTreeMap, + ) -> Option { + let (q_l, w_l) = term; + // Check if we have w_l + let w_l_value = witness_assignments.get(w_l); + w_l_value.map(|a| *q_l * *a) + } + /// Returns the summation of all of the variables, plus the unknown variable /// Returns None, if there is more than one unknown variable /// We cannot assign @@ -163,19 +179,14 @@ impl ArithmeticSolver { let mut result = FieldElement::zero(); for term in arith_gate.linear_combinations.iter() { - let q_l = term.0; - let w_l = &term.1; - - // Check if we have w_l - let w_l_value = witness_assignments.get(w_l); - - match w_l_value { - Some(a) => result += q_l * *a, + let value = ArithmeticSolver::solve_fan_in_term_helper(term, witness_assignments); + match value { + Some(a) => result += a, None => { unknown_variable = *term; num_unknowns += 1; } - }; + } // If we have more than 1 unknown, then we cannot solve this equation if num_unknowns > 1 { @@ -189,6 +200,39 @@ impl ArithmeticSolver { GateStatus::GateSolvable(result, unknown_variable) } + + // Partially evaluate the gate using the known witnesses + pub fn evaluate( + expr: &Expression, + initial_witness: &BTreeMap, + ) -> Expression { + let mut result = Expression::default(); + for &(c, w1, w2) in &expr.mul_terms { + let mul_result = ArithmeticSolver::solve_mul_term_helper(&(c, w1, w2), initial_witness); + match mul_result { + MulTerm::OneUnknown(v, w) => { + if !v.is_zero() { + result.linear_combinations.push((v, w)); + } + } + MulTerm::TooManyUnknowns => { + if !c.is_zero() { + result.mul_terms.push((c, w1, w2)); + } + } + MulTerm::Solved(f) => result.q_c += f, + } + } + for &(c, w) in &expr.linear_combinations { + if let Some(f) = ArithmeticSolver::solve_fan_in_term_helper(&(c, w), initial_witness) { + result.q_c += f; + } else if !c.is_zero() { + result.linear_combinations.push((c, w)); + } + } + result.q_c += expr.q_c; + result + } } #[test] diff --git a/acvm/src/pwg/directives.rs b/acvm/src/pwg/directives.rs index 94bc2d60b..67819efb3 100644 --- a/acvm/src/pwg/directives.rs +++ b/acvm/src/pwg/directives.rs @@ -1,12 +1,16 @@ -use std::collections::BTreeMap; +use std::{cmp::Ordering, collections::BTreeMap}; -use acir::{circuit::directives::Directive, native_types::Witness, FieldElement}; +use acir::{ + circuit::directives::{Directive, LogInfo}, + native_types::Witness, + FieldElement, +}; use num_bigint::BigUint; use num_traits::{One, Zero}; use crate::OpcodeResolutionError; -use super::{get_value, witness_to_value}; +use super::{get_value, sorting::route, witness_to_value}; pub fn solve_directives( initial_witness: &mut BTreeMap, @@ -45,8 +49,16 @@ pub fn solve_directives( (&int_a % &int_b, &int_a / &int_b) }; - initial_witness.insert(*q, FieldElement::from_be_bytes_reduce(&int_q.to_bytes_be())); - initial_witness.insert(*r, FieldElement::from_be_bytes_reduce(&int_r.to_bytes_be())); + insert_witness( + *q, + FieldElement::from_be_bytes_reduce(&int_q.to_bytes_be()), + initial_witness, + )?; + insert_witness( + *r, + FieldElement::from_be_bytes_reduce(&int_r.to_bytes_be()), + initial_witness, + )?; Ok(()) } @@ -59,8 +71,16 @@ pub fn solve_directives( let int_b: BigUint = &int_a % &pow; let int_c: BigUint = (&int_a - &int_b) / &pow; - initial_witness.insert(*b, FieldElement::from_be_bytes_reduce(&int_b.to_bytes_be())); - initial_witness.insert(*c, FieldElement::from_be_bytes_reduce(&int_c.to_bytes_be())); + insert_witness( + *b, + FieldElement::from_be_bytes_reduce(&int_b.to_bytes_be()), + initial_witness, + )?; + insert_witness( + *c, + FieldElement::from_be_bytes_reduce(&int_c.to_bytes_be()), + initial_witness, + )?; Ok(()) } @@ -78,16 +98,7 @@ pub fn solve_directives( } else { FieldElement::zero() }; - match initial_witness.entry(b[i]) { - std::collections::btree_map::Entry::Vacant(e) => { - e.insert(v); - } - std::collections::btree_map::Entry::Occupied(e) => { - if e.get() != &v { - return Err(OpcodeResolutionError::UnsatisfiedConstrain); - } - } - } + insert_witness(b[i], v, initial_witness)?; } Ok(()) @@ -105,10 +116,115 @@ pub fn solve_directives( let int_r = &int_a - &bb; let int_b = &bb >> (bit_size - 1); - initial_witness.insert(*b, FieldElement::from_be_bytes_reduce(&int_b.to_bytes_be())); - initial_witness.insert(*r, FieldElement::from_be_bytes_reduce(&int_r.to_bytes_be())); + insert_witness( + *b, + FieldElement::from_be_bytes_reduce(&int_b.to_bytes_be()), + initial_witness, + )?; + insert_witness( + *r, + FieldElement::from_be_bytes_reduce(&int_r.to_bytes_be()), + initial_witness, + )?; + + Ok(()) + } + Directive::PermutationSort { + inputs: a, + tuple, + bits, + sort_by, + } => { + let mut val_a = Vec::new(); + let mut base = Vec::new(); + for (i, element) in a.iter().enumerate() { + assert_eq!(element.len(), *tuple as usize); + let mut element_val = Vec::with_capacity(*tuple as usize + 1); + for e in element { + element_val.push(get_value(e, initial_witness)?); + } + let field_i = FieldElement::from(i as i128); + element_val.push(field_i); + base.push(field_i); + val_a.push(element_val); + } + val_a.sort_by(|a, b| { + for i in sort_by { + let int_a = BigUint::from_bytes_be(&a[*i as usize].to_be_bytes()); + let int_b = BigUint::from_bytes_be(&b[*i as usize].to_be_bytes()); + let cmp = int_a.cmp(&int_b); + if cmp != Ordering::Equal { + return cmp; + } + } + Ordering::Equal + }); + let b = val_a.iter().map(|a| *a.last().unwrap()).collect(); + let control = route(base, b); + for (w, value) in bits.iter().zip(control) { + let value = if value { + FieldElement::one() + } else { + FieldElement::zero() + }; + insert_witness(*w, value, initial_witness)?; + } + Ok(()) + } + Directive::Log(info) => { + let witnesses = match info { + LogInfo::FinalizedOutput(output_string) => { + println!("{output_string}"); + return Ok(()); + } + LogInfo::WitnessOutput(witnesses) => witnesses, + }; + + if witnesses.len() == 1 { + let witness = &witnesses[0]; + let log_value = witness_to_value(initial_witness, *witness)?; + println!("{}", log_value.to_hex()); + + return Ok(()); + } + + // If multiple witnesses are to be fetched for a log directive, + // it assumed that an array is meant to be printed to standard output + // + // Collect all field element values corresponding to the given witness indices + // and convert them to hex strings. + let mut elements_as_hex = Vec::with_capacity(witnesses.len()); + for witness in witnesses { + let element = witness_to_value(initial_witness, *witness)?; + elements_as_hex.push(element.to_hex()); + } + + // Join all of the hex strings using a comma + let comma_separated_elements = elements_as_hex.join(","); + + let output_witnesses_string = "[".to_owned() + &comma_separated_elements + "]"; + + println!("{output_witnesses_string}"); Ok(()) } } } + +fn insert_witness( + w: Witness, + value: FieldElement, + initial_witness: &mut BTreeMap, +) -> Result<(), OpcodeResolutionError> { + match initial_witness.entry(w) { + std::collections::btree_map::Entry::Vacant(e) => { + e.insert(value); + } + std::collections::btree_map::Entry::Occupied(e) => { + if e.get() != &value { + return Err(OpcodeResolutionError::UnsatisfiedConstrain); + } + } + } + Ok(()) +} diff --git a/acvm/src/pwg/sorting.rs b/acvm/src/pwg/sorting.rs new file mode 100644 index 000000000..8a138c85c --- /dev/null +++ b/acvm/src/pwg/sorting.rs @@ -0,0 +1,390 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use acir::FieldElement; + +// A sorting network is a graph of connected switches +// It is defined recursively so here we only keep track of the outer layer of switches +struct SortingNetwork { + n: usize, // size of the network + x_inputs: Vec, // inputs of the network + y_inputs: Vec, // outputs of the network + x_values: BTreeMap, // map for matching a y value with a x value + y_values: BTreeMap, // map for matching a x value with a y value + inner_x: Vec, // positions after the switch_x + inner_y: Vec, // positions after the sub-networks, and before the switch_y + switch_x: Vec, // outer switches for the inputs + switch_y: Vec, // outer switches for the outputs + free: BTreeSet, // outer switches available for looping +} + +impl SortingNetwork { + fn new(n: usize) -> SortingNetwork { + let free_len = (n - 1) / 2; + let mut free = BTreeSet::new(); + for i in 0..free_len { + free.insert(i); + } + SortingNetwork { + n, + x_inputs: Vec::with_capacity(n), + y_inputs: Vec::with_capacity(n), + x_values: BTreeMap::new(), + y_values: BTreeMap::new(), + inner_x: Vec::with_capacity(n), + inner_y: Vec::with_capacity(n), + switch_x: Vec::with_capacity(n / 2), + switch_y: Vec::with_capacity(free_len), + free, + } + } + + fn init(&mut self, inputs: Vec, outputs: Vec) { + let n = self.n; + assert_eq!(inputs.len(), outputs.len()); + assert_eq!(inputs.len(), n); + + self.x_inputs = inputs; + self.y_inputs = outputs; + for i in 0..self.n { + self.x_values.insert(self.x_inputs[i], i); + self.y_values.insert(self.y_inputs[i], i); + } + self.switch_x = vec![false; n / 2]; + self.switch_y = vec![false; (n - 1) / 2]; + self.inner_x = vec![FieldElement::zero(); n]; + self.inner_y = vec![FieldElement::zero(); n]; + + //Route the single wires so we do not need to handle this case later on + self.inner_y[n - 1] = self.y_inputs[n - 1]; + if n % 2 == 0 { + self.inner_y[n / 2 - 1] = self.y_inputs[n - 2]; + } else { + self.inner_x[n - 1] = self.x_inputs[n - 1]; + } + } + + //route a wire from outputs to its value in the inputs + fn route_out_wire(&mut self, y: usize, sub: bool) -> usize { + // sub <- y + if self.is_single_y(y) { + assert!(sub); + } else { + let port = y % 2 != 0; + let s1 = sub ^ port; + let inner = self.compute_inner(y, s1); + self.configure_y(y, s1, inner); + } + // x <- sub + let x = self.x_values.remove(&self.y_inputs[y]).unwrap(); + if !self.is_single_x(x) { + let port2 = x % 2 != 0; + let s2 = sub ^ port2; + let inner = self.compute_inner(x, s2); + self.configure_x(x, s2, inner); + } + x + } + + //route a wire from inputs to its value in the outputs + fn route_in_wire(&mut self, x: usize, sub: bool) -> usize { + // x -> sub + assert!(!self.is_single_x(x)); + let port = x % 2 != 0; + let s1 = sub ^ port; + let inner = self.compute_inner(x, s1); + self.configure_x(x, s1, inner); + + // sub -> y + let y = self.y_values.remove(&self.x_inputs[x]).unwrap(); + if !self.is_single_y(y) { + let port = y % 2 != 0; + let s2 = sub ^ port; + let inner = self.compute_inner(y, s2); + self.configure_y(y, s2, inner); + } + y + } + + //update the computed switch and inner values for an input wire + fn configure_x(&mut self, x: usize, switch: bool, inner: usize) { + self.inner_x[inner] = self.x_inputs[x]; + self.switch_x[x / 2] = switch; + } + + //update the computed switch and inner values for an output wire + fn configure_y(&mut self, y: usize, switch: bool, inner: usize) { + self.inner_y[inner] = self.y_inputs[y]; + self.switch_y[y / 2] = switch; + } + + // returns the other wire belonging to the same switch + fn sibling(index: usize) -> usize { + index + 1 - 2 * (index % 2) + } + + // returns a free switch + fn take(&mut self) -> Option { + self.free.first().copied() + } + + fn is_single_x(&self, a: usize) -> bool { + let n = self.x_inputs.len(); + n % 2 == 1 && a == n - 1 + } + + fn is_single_y(&mut self, a: usize) -> bool { + let n = self.x_inputs.len(); + a >= n - 2 + n % 2 + } + + // compute the inner position of idx through its switch + fn compute_inner(&self, idx: usize, switch: bool) -> usize { + if switch ^ (idx % 2 == 1) { + idx / 2 + self.n / 2 + } else { + idx / 2 + } + } + + fn new_start(&mut self) -> (Option, usize) { + let next = self.take(); + if let Some(switch) = next { + (next, 2 * switch) + } else { + (None, 0) + } + } +} + +// Computes the control bits of the sorting network which transform inputs into outputs +pub fn route(inputs: Vec, outputs: Vec) -> Vec { + assert_eq!(inputs.len(), outputs.len()); + match inputs.len() { + 0 => Vec::new(), + 1 => { + assert_eq!(inputs[0], outputs[0]); + Vec::new() + } + 2 => { + if inputs[0] == outputs[0] { + assert_eq!(inputs[1], outputs[1]); + vec![false] + } else { + assert_eq!(inputs[1], outputs[0]); + assert_eq!(inputs[0], outputs[1]); + vec![true] + } + } + _ => { + let n = inputs.len(); + + let mut result; + let n1 = n / 2; + let in_sub1; + let out_sub1; + let in_sub2; + let out_sub2; + + // process the outer layer in a code block so that the intermediate data is cleared before recursion + { + let mut network = SortingNetwork::new(n); + network.init(inputs, outputs); + + //We start with the last single wire + let mut out_idx = n - 1; + let mut start_sub = true; //it is connected to the lower inner network + let mut switch = None; + let mut start = None; + + while !network.free.is_empty() { + // the processed switch is no more available + if let Some(free_switch) = switch { + network.free.remove(&free_switch); + } + + // connect the output wire to its matching input + let in_idx = network.route_out_wire(out_idx, start_sub); + if network.is_single_x(in_idx) { + start_sub = !start_sub; //We need to restart, but did not complete the loop so we switch the subnetwork + (start, out_idx) = network.new_start(); + switch = start; + continue; + } + + // loop from the sibling + let next = SortingNetwork::sibling(in_idx); + // connect the input wire to its matching output, using the other sub-network + out_idx = network.route_in_wire(next, !start_sub); + switch = Some(out_idx / 2); + if start == switch || network.is_single_y(out_idx) { + //loop is complete, need a fresh start + (start, out_idx) = network.new_start(); + switch = start; + } else { + // we loop back from the sibling + out_idx = SortingNetwork::sibling(out_idx); + } + } + //All the wires are connected, we can now route the sub-networks + result = network.switch_x; + result.extend(network.switch_y); + in_sub1 = network.inner_x[0..n1].to_vec(); + in_sub2 = network.inner_x[n1..].to_vec(); + out_sub1 = network.inner_y[0..n1].to_vec(); + out_sub2 = network.inner_y[n1..].to_vec(); + } + let s1 = route(in_sub1, out_sub1); + result.extend(s1); + let s2 = route(in_sub2, out_sub2); + result.extend(s2); + result + } + } +} + +#[cfg(test)] +mod test { + use crate::pwg::sorting::route; + use acir::FieldElement; + use rand::prelude::*; + + pub fn execute_network(config: Vec, inputs: Vec) -> Vec { + let n = inputs.len(); + if n == 1 { + return inputs; + } + let mut in1 = Vec::new(); + let mut in2 = Vec::new(); + //layer 1: + for i in 0..n / 2 { + if config[i] { + in1.push(inputs[2 * i + 1]); + in2.push(inputs[2 * i]); + } else { + in1.push(inputs[2 * i]); + in2.push(inputs[2 * i + 1]); + } + } + if n % 2 == 1 { + in2.push(*inputs.last().unwrap()); + } + let n2 = n / 2 + (n - 1) / 2; + let n3 = n2 + switch_nb(n / 2); + let mut result = Vec::new(); + let out1 = execute_network(config[n2..n3].to_vec(), in1); + let out2 = execute_network(config[n3..].to_vec(), in2); + //last layer: + for i in 0..(n - 1) / 2 { + if config[n / 2 + i] { + result.push(out2[i]); + result.push(out1[i]); + } else { + result.push(out1[i]); + result.push(out2[i]); + } + } + if n % 2 == 0 { + result.push(*out1.last().unwrap()); + result.push(*out2.last().unwrap()); + } else { + result.push(*out2.last().unwrap()) + } + result + } + + pub fn switch_nb(n: usize) -> usize { + let mut s = 0; + for i in 0..n { + s += f64::from((i + 1) as u32).log2().ceil() as usize; + } + s + } + + #[test] + fn test_route() { + //basic tests + let a = vec![ + FieldElement::from(1_i128), + FieldElement::from(2_i128), + FieldElement::from(3_i128), + ]; + let b = vec![ + FieldElement::from(1_i128), + FieldElement::from(2_i128), + FieldElement::from(3_i128), + ]; + let c = route(a, b); + assert_eq!(c, vec![false, false, false]); + + let a = vec![ + FieldElement::from(1_i128), + FieldElement::from(2_i128), + FieldElement::from(3_i128), + ]; + let b = vec![ + FieldElement::from(1_i128), + FieldElement::from(3_i128), + FieldElement::from(2_i128), + ]; + let c = route(a, b); + assert_eq!(c, vec![false, false, true]); + + let a = vec![ + FieldElement::from(1_i128), + FieldElement::from(2_i128), + FieldElement::from(3_i128), + ]; + let b = vec![ + FieldElement::from(3_i128), + FieldElement::from(2_i128), + FieldElement::from(1_i128), + ]; + let c = route(a, b); + assert_eq!(c, vec![true, true, true]); + + let a = vec![ + FieldElement::from(0_i128), + FieldElement::from(1_i128), + FieldElement::from(2_i128), + FieldElement::from(3_i128), + ]; + let b = vec![ + FieldElement::from(2_i128), + FieldElement::from(3_i128), + FieldElement::from(0_i128), + FieldElement::from(1_i128), + ]; + let c = route(a, b); + assert_eq!(c, vec![false, true, true, true, true]); + + let a = vec![ + FieldElement::from(0_i128), + FieldElement::from(1_i128), + FieldElement::from(2_i128), + FieldElement::from(3_i128), + FieldElement::from(4_i128), + ]; + let b = vec![ + FieldElement::from(0_i128), + FieldElement::from(3_i128), + FieldElement::from(4_i128), + FieldElement::from(2_i128), + FieldElement::from(1_i128), + ]; + let c = route(a, b); + assert_eq!(c, vec![false, false, false, true, false, true, false, true]); + + // random tests + for i in 2..50 { + let mut a = vec![FieldElement::zero()]; + for j in 0..i - 1 { + a.push(a[j] + FieldElement::one()); + } + let mut rng = rand::thread_rng(); + let mut b = a.clone(); + b.shuffle(&mut rng); + let c = route(a.clone(), b.clone()); + assert_eq!(b, execute_network(c, a)); + } + } +} diff --git a/cspell.json b/cspell.json new file mode 100644 index 000000000..8a4bf70c6 --- /dev/null +++ b/cspell.json @@ -0,0 +1,44 @@ +{ + "version": "0.2", + "words": [ + "blackbox", + // In code + // + "acir", + "ACIR", + "ACVM", + "Axyz", + "arithmetization", + "bivariate", + "canonicalize", + "coeff", + "consts", + "csat", + "decomp", + "deflater", + "endianness", + "euclidian", + "hasher", + "Merkle", + "OddRange", + "Pedersen", + "PLONKC", + "prehashed", + "pubkey", + "repr", + "secp", + "Schnorr", + "Shleft", + "Shright", + "stdlib", + "struct", + "TORADIX", + // Dependencies + // + "bufread", + "flate", + "indexmap", + "thiserror", + "typenum" + ] +} diff --git a/stdlib/src/fallback.rs b/stdlib/src/fallback.rs index 818c8b874..9aaaa2c37 100644 --- a/stdlib/src/fallback.rs +++ b/stdlib/src/fallback.rs @@ -68,7 +68,7 @@ pub(crate) fn bit_decomposition( bit_decomp_constraint.sort(); // TODO: we have an issue open to check if this is needed. Ideally, we remove it. new_gates.push(Opcode::Arithmetic(bit_decomp_constraint)); - (new_gates, bit_vector, variables.finalise()) + (new_gates, bit_vector, variables.finalize()) } // Range constraint @@ -140,7 +140,7 @@ pub fn xor( let two = FieldElement::from(2_i128); // Build an xor expression - // TODO: check this is the correct arithmetisation + // TODO: check this is the correct arithmetization let mut xor_expr = Expression::default(); for (a_bit, b_bit) in a_bits.into_iter().zip(b_bits) { xor_expr.term_addition(two_pow, a_bit); diff --git a/stdlib/src/helpers.rs b/stdlib/src/helpers.rs index 4eea6a518..5ab258368 100644 --- a/stdlib/src/helpers.rs +++ b/stdlib/src/helpers.rs @@ -17,7 +17,7 @@ impl<'a> VariableStore<'a> { witness } - pub fn finalise(self) -> u32 { + pub fn finalize(self) -> u32 { *self.witness_index } } From 248bd423d23ed3e26118063eb52a77c519d69f8c Mon Sep 17 00:00:00 2001 From: guipublic Date: Mon, 6 Feb 2023 16:28:57 +0000 Subject: [PATCH 2/3] spell check --- acir/src/circuit/directives.rs | 4 ++-- acvm/src/pwg/sorting.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/acir/src/circuit/directives.rs b/acir/src/circuit/directives.rs index 585cd6a9b..5c4026f46 100644 --- a/acir/src/circuit/directives.rs +++ b/acir/src/circuit/directives.rs @@ -50,12 +50,12 @@ pub enum Directive { }, // Sort directive, using a sorting network - // This directive is used to generate the values of the control bits for the sorting network such that its ouputs are properly sorted accroding to sort_by + // This directive is used to generate the values of the control bits for the sorting network such that its outputs are properly sorted accroding to sort_by PermutationSort { inputs: Vec>, // Array of tuples to sort tuple: u32, // tuple size; if 1 then inputs is a single array [a0,a1,..], if 2 then inputs=[(a0,b0),..] is [a0,b0,a1,b1,..], etc.. bits: Vec, // control bits of the network which permutes the inputs into its sorted version - sort_by: Vec, // specify primary index to sort by, then the secondary,... For instance, if typle is 2 and sort_by is [1,0], then a=[(a0,b0),..] is sorted by bi and then ai. + sort_by: Vec, // specify primary index to sort by, then the secondary,... For instance, if tuple is 2 and sort_by is [1,0], then a=[(a0,b0),..] is sorted by bi and then ai. }, Log(LogInfo), } diff --git a/acvm/src/pwg/sorting.rs b/acvm/src/pwg/sorting.rs index 8a138c85c..beecafad4 100644 --- a/acvm/src/pwg/sorting.rs +++ b/acvm/src/pwg/sorting.rs @@ -205,7 +205,7 @@ pub fn route(inputs: Vec, outputs: Vec) -> Vec // connect the output wire to its matching input let in_idx = network.route_out_wire(out_idx, start_sub); if network.is_single_x(in_idx) { - start_sub = !start_sub; //We need to restart, but did not complete the loop so we switch the subnetwork + start_sub = !start_sub; //We need to restart, but did not complete the loop so we switch the sub network (start, out_idx) = network.new_start(); switch = start; continue; From dc04e89f7c4a147800e2f021291519fefca85ab8 Mon Sep 17 00:00:00 2001 From: guipublic Date: Tue, 7 Feb 2023 12:05:51 +0000 Subject: [PATCH 3/3] some comments --- acir/src/circuit/directives.rs | 2 +- acvm/src/pwg/sorting.rs | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/acir/src/circuit/directives.rs b/acir/src/circuit/directives.rs index 5c4026f46..f3f5e6aca 100644 --- a/acir/src/circuit/directives.rs +++ b/acir/src/circuit/directives.rs @@ -50,7 +50,7 @@ pub enum Directive { }, // Sort directive, using a sorting network - // This directive is used to generate the values of the control bits for the sorting network such that its outputs are properly sorted accroding to sort_by + // This directive is used to generate the values of the control bits for the sorting network such that its outputs are properly sorted according to sort_by PermutationSort { inputs: Vec>, // Array of tuples to sort tuple: u32, // tuple size; if 1 then inputs is a single array [a0,a1,..], if 2 then inputs=[(a0,b0),..] is [a0,b0,a1,b1,..], etc.. diff --git a/acvm/src/pwg/sorting.rs b/acvm/src/pwg/sorting.rs index beecafad4..217d0c0c2 100644 --- a/acvm/src/pwg/sorting.rs +++ b/acvm/src/pwg/sorting.rs @@ -157,6 +157,7 @@ impl SortingNetwork { } // Computes the control bits of the sorting network which transform inputs into outputs +// implementation is based on https://www.mdpi.com/2227-7080/10/1/16 pub fn route(inputs: Vec, outputs: Vec) -> Vec { assert_eq!(inputs.len(), outputs.len()); match inputs.len() { @@ -292,6 +293,7 @@ mod test { result } + // returns the number of switches in the network pub fn switch_nb(n: usize) -> usize { let mut s = 0; for i in 0..n {