From 64b436b458c7da35eb7aa395362983fb66134c02 Mon Sep 17 00:00:00 2001 From: guipublic Date: Fri, 26 Jan 2024 11:49:44 +0000 Subject: [PATCH] Add opcode for sha256 compression function --- .../dsl/acir_format/serde/acir.hpp | 141 +++++++++++++++++- noir/acvm-repo/acir/codegen/acir.cpp | 112 +++++++++++++- .../acir/src/circuit/black_box_functions.rs | 5 + .../opcodes/black_box_function_call.rs | 22 ++- .../acvm/src/compiler/transformers/mod.rs | 3 + noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs | 1 + noir/acvm-repo/brillig/src/black_box.rs | 5 + noir/acvm-repo/brillig_vm/src/black_box.rs | 2 + .../brillig/brillig_gen/brillig_black_box.rs | 17 ++- .../src/brillig/brillig_ir/debug_show.rs | 9 ++ .../ssa/acir_gen/acir_ir/generated_acir.rs | 8 + .../src/ssa/ir/instruction/call.rs | 1 + noir/noir_stdlib/src/hash.nr | 3 + 13 files changed, 322 insertions(+), 7 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp index e017ee8a3e9..97e73a25c69 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp @@ -265,6 +265,16 @@ struct BlackBoxFuncCall { static Poseidon2Permutation bincodeDeserialize(std::vector); }; + struct Sha256Compression { + std::vector inputs; + std::vector hash_values; + std::vector outputs; + + friend bool operator==(const Sha256Compression&, const Sha256Compression&); + std::vector bincodeSerialize() const; + static Sha256Compression bincodeDeserialize(std::vector); + }; + std::variant + Poseidon2Permutation, + Sha256Compression> value; friend bool operator==(const BlackBoxFuncCall&, const BlackBoxFuncCall&); @@ -685,6 +696,16 @@ struct BlackBoxOp { static Poseidon2Permutation bincodeDeserialize(std::vector); }; + struct Sha256Compression { + Circuit::HeapVector input; + Circuit::HeapVector hash_values; + Circuit::HeapArray output; + + friend bool operator==(const Sha256Compression&, const Sha256Compression&); + std::vector bincodeSerialize() const; + static Sha256Compression bincodeDeserialize(std::vector); + }; + std::variant + Poseidon2Permutation, + Sha256Compression> value; friend bool operator==(const BlackBoxOp&, const BlackBoxOp&); @@ -3350,6 +3372,64 @@ Circuit::BlackBoxFuncCall::Poseidon2Permutation serde::Deserializable< namespace Circuit { +inline bool operator==(const BlackBoxFuncCall::Sha256Compression& lhs, const BlackBoxFuncCall::Sha256Compression& rhs) +{ + if (!(lhs.inputs == rhs.inputs)) { + return false; + } + if (!(lhs.hash_values == rhs.hash_values)) { + return false; + } + if (!(lhs.outputs == rhs.outputs)) { + return false; + } + return true; +} + +inline std::vector BlackBoxFuncCall::Sha256Compression::bincodeSerialize() const +{ + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); +} + +inline BlackBoxFuncCall::Sha256Compression BlackBoxFuncCall::Sha256Compression::bincodeDeserialize( + std::vector input) +{ + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw_or_abort("Some input bytes were not read"); + } + return value; +} + +} // end of namespace Circuit + +template <> +template +void serde::Serializable::serialize( + const Circuit::BlackBoxFuncCall::Sha256Compression& obj, Serializer& serializer) +{ + serde::Serializable::serialize(obj.inputs, serializer); + serde::Serializable::serialize(obj.hash_values, serializer); + serde::Serializable::serialize(obj.outputs, serializer); +} + +template <> +template +Circuit::BlackBoxFuncCall::Sha256Compression serde::Deserializable< + Circuit::BlackBoxFuncCall::Sha256Compression>::deserialize(Deserializer& deserializer) +{ + Circuit::BlackBoxFuncCall::Sha256Compression obj; + obj.inputs = serde::Deserializable::deserialize(deserializer); + obj.hash_values = serde::Deserializable::deserialize(deserializer); + obj.outputs = serde::Deserializable::deserialize(deserializer); + return obj; +} + +namespace Circuit { + inline bool operator==(const BlackBoxOp& lhs, const BlackBoxOp& rhs) { if (!(lhs.value == rhs.value)) { @@ -4490,6 +4570,63 @@ Circuit::BlackBoxOp::Poseidon2Permutation serde::Deserializable BlackBoxOp::Sha256Compression::bincodeSerialize() const +{ + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); +} + +inline BlackBoxOp::Sha256Compression BlackBoxOp::Sha256Compression::bincodeDeserialize(std::vector input) +{ + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw_or_abort("Some input bytes were not read"); + } + return value; +} + +} // end of namespace Circuit + +template <> +template +void serde::Serializable::serialize( + const Circuit::BlackBoxOp::Sha256Compression& obj, Serializer& serializer) +{ + serde::Serializable::serialize(obj.input, serializer); + serde::Serializable::serialize(obj.hash_values, serializer); + serde::Serializable::serialize(obj.output, serializer); +} + +template <> +template +Circuit::BlackBoxOp::Sha256Compression serde::Deserializable::deserialize( + Deserializer& deserializer) +{ + Circuit::BlackBoxOp::Sha256Compression obj; + obj.input = serde::Deserializable::deserialize(deserializer); + obj.hash_values = serde::Deserializable::deserialize(deserializer); + obj.output = serde::Deserializable::deserialize(deserializer); + return obj; +} + +namespace Circuit { + inline bool operator==(const BlockId& lhs, const BlockId& rhs) { if (!(lhs.value == rhs.value)) { diff --git a/noir/acvm-repo/acir/codegen/acir.cpp b/noir/acvm-repo/acir/codegen/acir.cpp index 487bb33a6b2..0f94e91ab10 100644 --- a/noir/acvm-repo/acir/codegen/acir.cpp +++ b/noir/acvm-repo/acir/codegen/acir.cpp @@ -265,7 +265,17 @@ namespace Circuit { static Poseidon2Permutation bincodeDeserialize(std::vector); }; - std::variant value; + struct Sha256Compression { + std::vector inputs; + std::vector hash_values; + std::vector outputs; + + friend bool operator==(const Sha256Compression&, const Sha256Compression&); + std::vector bincodeSerialize() const; + static Sha256Compression bincodeDeserialize(std::vector); + }; + + std::variant value; friend bool operator==(const BlackBoxFuncCall&, const BlackBoxFuncCall&); std::vector bincodeSerialize() const; @@ -661,7 +671,17 @@ namespace Circuit { static Poseidon2Permutation bincodeDeserialize(std::vector); }; - std::variant value; + struct Sha256Compression { + Circuit::HeapVector input; + Circuit::HeapVector hash_values; + Circuit::HeapArray output; + + friend bool operator==(const Sha256Compression&, const Sha256Compression&); + std::vector bincodeSerialize() const; + static Sha256Compression bincodeDeserialize(std::vector); + }; + + std::variant value; friend bool operator==(const BlackBoxOp&, const BlackBoxOp&); std::vector bincodeSerialize() const; @@ -2848,6 +2868,50 @@ Circuit::BlackBoxFuncCall::Poseidon2Permutation serde::Deserializable BlackBoxFuncCall::Sha256Compression::bincodeSerialize() const { + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); + } + + inline BlackBoxFuncCall::Sha256Compression BlackBoxFuncCall::Sha256Compression::bincodeDeserialize(std::vector input) { + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw serde::deserialization_error("Some input bytes were not read"); + } + return value; + } + +} // end of namespace Circuit + +template <> +template +void serde::Serializable::serialize(const Circuit::BlackBoxFuncCall::Sha256Compression &obj, Serializer &serializer) { + serde::Serializable::serialize(obj.inputs, serializer); + serde::Serializable::serialize(obj.hash_values, serializer); + serde::Serializable::serialize(obj.outputs, serializer); +} + +template <> +template +Circuit::BlackBoxFuncCall::Sha256Compression serde::Deserializable::deserialize(Deserializer &deserializer) { + Circuit::BlackBoxFuncCall::Sha256Compression obj; + obj.inputs = serde::Deserializable::deserialize(deserializer); + obj.hash_values = serde::Deserializable::deserialize(deserializer); + obj.outputs = serde::Deserializable::deserialize(deserializer); + return obj; +} + namespace Circuit { inline bool operator==(const BlackBoxOp &lhs, const BlackBoxOp &rhs) { @@ -3732,6 +3796,50 @@ Circuit::BlackBoxOp::Poseidon2Permutation serde::Deserializable BlackBoxOp::Sha256Compression::bincodeSerialize() const { + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); + } + + inline BlackBoxOp::Sha256Compression BlackBoxOp::Sha256Compression::bincodeDeserialize(std::vector input) { + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw serde::deserialization_error("Some input bytes were not read"); + } + return value; + } + +} // end of namespace Circuit + +template <> +template +void serde::Serializable::serialize(const Circuit::BlackBoxOp::Sha256Compression &obj, Serializer &serializer) { + serde::Serializable::serialize(obj.input, serializer); + serde::Serializable::serialize(obj.hash_values, serializer); + serde::Serializable::serialize(obj.output, serializer); +} + +template <> +template +Circuit::BlackBoxOp::Sha256Compression serde::Deserializable::deserialize(Deserializer &deserializer) { + Circuit::BlackBoxOp::Sha256Compression obj; + obj.input = serde::Deserializable::deserialize(deserializer); + obj.hash_values = serde::Deserializable::deserialize(deserializer); + obj.output = serde::Deserializable::deserialize(deserializer); + return obj; +} + namespace Circuit { inline bool operator==(const BlockId &lhs, const BlockId &rhs) { diff --git a/noir/acvm-repo/acir/src/circuit/black_box_functions.rs b/noir/acvm-repo/acir/src/circuit/black_box_functions.rs index 358722900ba..97b4759d350 100644 --- a/noir/acvm-repo/acir/src/circuit/black_box_functions.rs +++ b/noir/acvm-repo/acir/src/circuit/black_box_functions.rs @@ -61,6 +61,8 @@ pub enum BlackBoxFunc { BigIntToLeBytes, /// Permutation function of Poseidon2 Poseidon2Permutation, + /// SHA256 compression function + Sha256Compression, } impl std::fmt::Display for BlackBoxFunc { @@ -95,6 +97,7 @@ impl BlackBoxFunc { BlackBoxFunc::BigIntFromLeBytes => "bigint_from_le_bytes", BlackBoxFunc::BigIntToLeBytes => "bigint_to_le_bytes", BlackBoxFunc::Poseidon2Permutation => "poseidon2_permutation", + BlackBoxFunc::Sha256Compression => "sha256_compression", } } @@ -123,9 +126,11 @@ impl BlackBoxFunc { "bigint_from_le_bytes" => Some(BlackBoxFunc::BigIntFromLeBytes), "bigint_to_le_bytes" => Some(BlackBoxFunc::BigIntToLeBytes), "poseidon2_permutation" => Some(BlackBoxFunc::Poseidon2Permutation), + "sha256_compression" => Some(BlackBoxFunc::Sha256Compression), _ => None, } } + pub fn is_valid_black_box_func_name(op_name: &str) -> bool { BlackBoxFunc::lookup(op_name).is_some() } diff --git a/noir/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs b/noir/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs index 110a524f746..ba4964c8912 100644 --- a/noir/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs +++ b/noir/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs @@ -155,6 +155,21 @@ pub enum BlackBoxFuncCall { /// It is the length of inputs and outputs vectors len: u32, }, + /// Applies the SHA-256 compression function to the input message + /// + /// # Arguments + /// + /// * `inputs` - input message block + /// * `hash_values` - state from the previous compression + /// * `outputs` - result of the input compressed into 256 bits + Sha256Compression { + /// 512 bits of the input message, represented by 16 u32s + inputs: Vec, + /// Vector of 8 u32s used to compress the input + hash_values: Vec, + /// Output of the compression, represented by 8 u32s + outputs: Vec, + }, } impl BlackBoxFuncCall { @@ -184,6 +199,7 @@ impl BlackBoxFuncCall { BlackBoxFuncCall::BigIntFromLeBytes { .. } => BlackBoxFunc::BigIntFromLeBytes, BlackBoxFuncCall::BigIntToLeBytes { .. } => BlackBoxFunc::BigIntToLeBytes, BlackBoxFuncCall::Poseidon2Permutation { .. } => BlackBoxFunc::Poseidon2Permutation, + BlackBoxFuncCall::Sha256Compression { .. } => BlackBoxFunc::Sha256Compression, } } @@ -201,7 +217,8 @@ impl BlackBoxFuncCall { | BlackBoxFuncCall::PedersenCommitment { inputs, .. } | BlackBoxFuncCall::PedersenHash { inputs, .. } | BlackBoxFuncCall::BigIntFromLeBytes { inputs, .. } - | BlackBoxFuncCall::Poseidon2Permutation { inputs, .. } => inputs.to_vec(), + | BlackBoxFuncCall::Poseidon2Permutation { inputs, .. } + | BlackBoxFuncCall::Sha256Compression { inputs, .. } => inputs.to_vec(), BlackBoxFuncCall::AND { lhs, rhs, .. } | BlackBoxFuncCall::XOR { lhs, rhs, .. } => { vec![*lhs, *rhs] } @@ -296,7 +313,8 @@ impl BlackBoxFuncCall { | BlackBoxFuncCall::Keccak256 { outputs, .. } | BlackBoxFuncCall::Keccakf1600 { outputs, .. } | BlackBoxFuncCall::Keccak256VariableLength { outputs, .. } - | BlackBoxFuncCall::Poseidon2Permutation { outputs, .. } => outputs.to_vec(), + | BlackBoxFuncCall::Poseidon2Permutation { outputs, .. } + | BlackBoxFuncCall::Sha256Compression { outputs, .. } => outputs.to_vec(), BlackBoxFuncCall::AND { output, .. } | BlackBoxFuncCall::XOR { output, .. } | BlackBoxFuncCall::SchnorrVerify { output, .. } diff --git a/noir/acvm-repo/acvm/src/compiler/transformers/mod.rs b/noir/acvm-repo/acvm/src/compiler/transformers/mod.rs index 246eeadc095..970eb9390bb 100644 --- a/noir/acvm-repo/acvm/src/compiler/transformers/mod.rs +++ b/noir/acvm-repo/acvm/src/compiler/transformers/mod.rs @@ -126,6 +126,9 @@ pub(super) fn transform_internal( | acir::circuit::opcodes::BlackBoxFuncCall::Poseidon2Permutation { outputs, .. + } + | acir::circuit::opcodes::BlackBoxFuncCall::Sha256Compression { + outputs, .. } => { for witness in outputs { transformer.mark_solvable(*witness); diff --git a/noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs b/noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs index 3baf99710ad..0f026cd274a 100644 --- a/noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs +++ b/noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs @@ -197,5 +197,6 @@ pub(crate) fn solve( BlackBoxFuncCall::BigIntFromLeBytes { .. } => todo!(), BlackBoxFuncCall::BigIntToLeBytes { .. } => todo!(), BlackBoxFuncCall::Poseidon2Permutation { .. } => todo!(), + BlackBoxFuncCall::Sha256Compression { .. } => todo!(), } } diff --git a/noir/acvm-repo/brillig/src/black_box.rs b/noir/acvm-repo/brillig/src/black_box.rs index f5f5c53803e..22fac6f3ba3 100644 --- a/noir/acvm-repo/brillig/src/black_box.rs +++ b/noir/acvm-repo/brillig/src/black_box.rs @@ -114,4 +114,9 @@ pub enum BlackBoxOp { output: HeapArray, len: RegisterIndex, }, + Sha256Compression { + input: HeapVector, + hash_values: HeapVector, + output: HeapArray, + }, } diff --git a/noir/acvm-repo/brillig_vm/src/black_box.rs b/noir/acvm-repo/brillig_vm/src/black_box.rs index 9935005a5ea..e9c25200c47 100644 --- a/noir/acvm-repo/brillig_vm/src/black_box.rs +++ b/noir/acvm-repo/brillig_vm/src/black_box.rs @@ -200,6 +200,7 @@ pub(crate) fn evaluate_black_box( BlackBoxOp::BigIntFromLeBytes { .. } => todo!(), BlackBoxOp::BigIntToLeBytes { .. } => todo!(), BlackBoxOp::Poseidon2Permutation { .. } => todo!(), + BlackBoxOp::Sha256Compression { .. } => todo!(), } } @@ -224,6 +225,7 @@ fn black_box_function_from_op(op: &BlackBoxOp) -> BlackBoxFunc { BlackBoxOp::BigIntFromLeBytes { .. } => BlackBoxFunc::BigIntFromLeBytes, BlackBoxOp::BigIntToLeBytes { .. } => BlackBoxFunc::BigIntToLeBytes, BlackBoxOp::Poseidon2Permutation { .. } => BlackBoxFunc::Poseidon2Permutation, + BlackBoxOp::Sha256Compression { .. } => BlackBoxFunc::Sha256Compression, } } diff --git a/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs b/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs index 62a1cef50a1..96d80cb8131 100644 --- a/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs +++ b/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs @@ -341,7 +341,22 @@ pub(crate) fn convert_black_box_call( len: *state_len, }); } else { - unreachable!("ICE: SHA256 expects one array argument and one array result") + unreachable!("ICE: Poseidon2Permutation expects one array argument, a length and one array result") + } + } + BlackBoxFunc::Sha256Compression => { + if let ([message, hash_values], [BrilligVariable::BrilligArray(result_array)]) = + (function_arguments, function_results) + { + let message_vector = convert_array_or_vector(brillig_context, message, bb_func); + let hash_vector = convert_array_or_vector(brillig_context, hash_values, bb_func); + brillig_context.black_box_op_instruction(BlackBoxOp::Sha256Compression { + input: message_vector.to_heap_vector(), + hash_values: hash_vector.to_heap_vector(), + output: result_array.to_heap_array(), + }); + } else { + unreachable!("ICE: Sha256Compression expects two array argument, one array result") } } } diff --git a/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs b/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs index b36081d5e3e..a8563dc9efe 100644 --- a/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs +++ b/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs @@ -510,6 +510,15 @@ impl DebugShow { output ); } + BlackBoxOp::Sha256Compression { input, hash_values, output } => { + debug_println!( + self.enable_debug_trace, + " SHA256COMPRESSION {} {} -> {}", + input, + hash_values, + output + ); + } } } diff --git a/noir/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs b/noir/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs index 1730a1f6a20..b86fc4eeb5f 100644 --- a/noir/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs +++ b/noir/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs @@ -281,6 +281,11 @@ impl GeneratedAcir { outputs, len: constant_inputs[0].to_u128() as u32, }, + BlackBoxFunc::Sha256Compression => BlackBoxFuncCall::Sha256Compression { + inputs: inputs[0].clone(), + hash_values: inputs[1].clone(), + outputs, + }, }; self.push_opcode(AcirOpcode::BlackBoxFuncCall(black_box_func_call)); @@ -617,6 +622,8 @@ fn black_box_func_expected_input_size(name: BlackBoxFunc) -> Option { // The permutation takes a fixed number of inputs, but the inputs length depends on the proving system implementation. BlackBoxFunc::Poseidon2Permutation => None, + // SHA256 compression requires 16 u32s as input message and 8 u32s for the hash state. + BlackBoxFunc::Sha256Compression => Some(24), // Can only apply a range constraint to one // witness at a time. BlackBoxFunc::RANGE => Some(1), @@ -667,6 +674,7 @@ fn black_box_expected_output_size(name: BlackBoxFunc) -> Option { // The permutation returns a fixed number of outputs, equals to the inputs length which depends on the proving system implementation. BlackBoxFunc::Poseidon2Permutation => None, + BlackBoxFunc::Sha256Compression => Some(8), // Pedersen commitment returns a point BlackBoxFunc::PedersenCommitment => Some(2), diff --git a/noir/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/noir/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs index eab839d9569..0178ae9dba1 100644 --- a/noir/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs +++ b/noir/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -452,6 +452,7 @@ fn simplify_black_box_func( "ICE: `BlackBoxFunc::RANGE` calls should be transformed into a `Instruction::Cast`" ) } + BlackBoxFunc::Sha256Compression => SimplifyResult::None, //TODO(Guillaume) } } diff --git a/noir/noir_stdlib/src/hash.nr b/noir/noir_stdlib/src/hash.nr index d53729f423f..c82d7722ca8 100644 --- a/noir/noir_stdlib/src/hash.nr +++ b/noir/noir_stdlib/src/hash.nr @@ -53,3 +53,6 @@ pub fn keccak256(_input: [u8; N], _message_size: u32) -> [u8; 32] {} #[foreign(poseidon2_permutation)] pub fn poseidon2_permutation(_input: [u8; N], _state_length: u32) -> [u8; N] {} + +#[foreign(sha256_compression)] +pub fn sha256_compression(_input: [u32; 16], _state: [u32; 8]) -> [u32; 8] {}