Skip to content

Commit

Permalink
chore: constant inputs for most blackboxes (#7613)
Browse files Browse the repository at this point in the history
blackbox can use constant inputs:
aes128
keccakf1600
sha256compression
ecoperations
blake2s
Blake3s
logic
sha256compression

The following have not been modified to use constant inputs:
sha256 and keccak256 because they are deprecated.
block, recursion and honk recursion because I am not sure they support
it.
shnorr and ecdsa, because they might be replaced.
  • Loading branch information
guipublic authored Jul 29, 2024
1 parent 96492dc commit 3247058
Show file tree
Hide file tree
Showing 26 changed files with 227 additions and 223 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ TEST_F(AcirFormatTests, TestLogicGateFromNoirCircuit)
};

LogicConstraint logic_constraint{
.a = 0,
.b = 1,
.a = WitnessOrConstant<bb::fr>::from_index(0),
.b = WitnessOrConstant<bb::fr>::from_index(1),
.result = 2,
.num_bits = 32,
.is_xor_gate = 1,
Expand Down Expand Up @@ -510,7 +510,33 @@ TEST_F(AcirFormatTests, TestKeccakPermutation)
{
Keccakf1600
keccak_permutation{
.state = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 },
.state = {
WitnessOrConstant<bb::fr>::from_index(1),
WitnessOrConstant<bb::fr>::from_index(2),
WitnessOrConstant<bb::fr>::from_index(3),
WitnessOrConstant<bb::fr>::from_index(4),
WitnessOrConstant<bb::fr>::from_index(5),
WitnessOrConstant<bb::fr>::from_index(6),
WitnessOrConstant<bb::fr>::from_index(7),
WitnessOrConstant<bb::fr>::from_index(8),
WitnessOrConstant<bb::fr>::from_index(9),
WitnessOrConstant<bb::fr>::from_index(10),
WitnessOrConstant<bb::fr>::from_index(11),
WitnessOrConstant<bb::fr>::from_index(12),
WitnessOrConstant<bb::fr>::from_index(13),
WitnessOrConstant<bb::fr>::from_index(14),
WitnessOrConstant<bb::fr>::from_index(15),
WitnessOrConstant<bb::fr>::from_index(16),
WitnessOrConstant<bb::fr>::from_index(17),
WitnessOrConstant<bb::fr>::from_index(18),
WitnessOrConstant<bb::fr>::from_index(19),
WitnessOrConstant<bb::fr>::from_index(20),
WitnessOrConstant<bb::fr>::from_index(21),
WitnessOrConstant<bb::fr>::from_index(22),
WitnessOrConstant<bb::fr>::from_index(23),
WitnessOrConstant<bb::fr>::from_index(24),
WitnessOrConstant<bb::fr>::from_index(25),
},
.result = { 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50 },
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,61 +195,34 @@ uint32_t get_witness_from_function_input(Program::FunctionInput input)
return input_witness.value.value;
}

WitnessConstant<bb::fr> parse_input(Program::FunctionInput input)
WitnessOrConstant<bb::fr> parse_input(Program::FunctionInput input)
{
WitnessConstant result = std::visit(
WitnessOrConstant result = std::visit(
[&](auto&& e) {
using T = std::decay_t<decltype(e)>;
if constexpr (std::is_same_v<T, Program::ConstantOrWitnessEnum::Witness>) {
return WitnessConstant<bb::fr>{
return WitnessOrConstant<bb::fr>{
.index = e.value.value,
.value = bb::fr::zero(),
.is_constant = false,
};
} else if constexpr (std::is_same_v<T, Program::ConstantOrWitnessEnum::Constant>) {
return WitnessConstant<bb::fr>{
return WitnessOrConstant<bb::fr>{
.index = 0,
.value = uint256_t(e.value),
.is_constant = true,
};
} else {
ASSERT(false);
}
return WitnessConstant<bb::fr>{
return WitnessOrConstant<bb::fr>{
.index = 0,
.value = bb::fr::zero(),
.is_constant = true,
};
},
input.input.value);
return result;

// WitnessConstant result = std::visit(
// [&](auto&& e) {
// using T = std::decay_t<decltype(e)>;
// if constexpr (std::is_same_v<T, Program::FunctionInput::Witness>) {
// return WitnessConstant<bb::fr>{
// .index = e.value.witness.value,
// .value = bb::fr::zero(),
// .is_constant = false,
// };
// } else if constexpr (std::is_same_v<T, Program::FunctionInput::Constant>) {
// return WitnessConstant<bb::fr>{
// .index = 0,
// .value = uint256_t(e.value.constant),
// .is_constant = true,
// };
// } else {
// ASSERT(false);
// }
// return WitnessConstant<bb::fr>{
// .index = 0,
// .value = bb::fr::zero(),
// .is_constant = true,
// };
// },
// input.value);
// return result;
}

void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
Expand All @@ -261,8 +234,8 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
[&](auto&& arg) {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::AND>) {
auto lhs_input = get_witness_from_function_input(arg.lhs);
auto rhs_input = get_witness_from_function_input(arg.rhs);
auto lhs_input = parse_input(arg.lhs);
auto rhs_input = parse_input(arg.rhs);
af.logic_constraints.push_back(LogicConstraint{
.a = lhs_input,
.b = rhs_input,
Expand All @@ -272,8 +245,8 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
});
af.original_opcode_indices.logic_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::XOR>) {
auto lhs_input = get_witness_from_function_input(arg.lhs);
auto rhs_input = get_witness_from_function_input(arg.rhs);
auto lhs_input = parse_input(arg.lhs);
auto rhs_input = parse_input(arg.rhs);
af.logic_constraints.push_back(LogicConstraint{
.a = lhs_input,
.b = rhs_input,
Expand All @@ -292,29 +265,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,

} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::AES128Encrypt>) {
af.aes128_constraints.push_back(AES128Constraint{
.inputs = map(arg.inputs,
[](auto& e) {
return AES128Input{
.witness = get_witness_from_function_input(e),
.num_bits = e.num_bits,
};
}),
.iv = map(arg.iv,
[](auto& e) {
auto witness = get_witness_from_function_input(e);
return AES128Input{
.witness = witness,
.num_bits = e.num_bits,
};
}),
.key = map(arg.key,
[](auto& e) {
auto input_witness = get_witness_from_function_input(e);
return AES128Input{
.witness = input_witness,
.num_bits = e.num_bits,
};
}),
.inputs = map(arg.inputs, [](auto& e) { return parse_input(e); }),
.iv = map(arg.iv, [](auto& e) { return parse_input(e); }),
.key = map(arg.key, [](auto& e) { return parse_input(e); }),
.outputs = map(arg.outputs, [](auto& e) { return e.value; }),
});
af.original_opcode_indices.aes128_constraints.push_back(opcode_index);
Expand All @@ -335,32 +288,17 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,

} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::Sha256Compression>) {
af.sha256_compression.push_back(Sha256Compression{
.inputs = map(arg.inputs,
[](auto& e) {
auto input_witness = get_witness_from_function_input(e);
return Sha256Input{
.witness = input_witness,
.num_bits = e.num_bits,
};
}),
.hash_values = map(arg.hash_values,
[](auto& e) {
auto input_witness = get_witness_from_function_input(e);
return Sha256Input{
.witness = input_witness,
.num_bits = e.num_bits,
};
}),
.inputs = map(arg.inputs, [](auto& e) { return parse_input(e); }),
.hash_values = map(arg.hash_values, [](auto& e) { return parse_input(e); }),
.result = map(arg.outputs, [](auto& e) { return e.value; }),
});
af.original_opcode_indices.sha256_compression.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::Blake2s>) {
af.blake2s_constraints.push_back(Blake2sConstraint{
.inputs = map(arg.inputs,
[](auto& e) {
auto input_witness = get_witness_from_function_input(e);
return Blake2sInput{
.witness = input_witness,
.blackbox_input = parse_input(e),
.num_bits = e.num_bits,
};
}),
Expand All @@ -371,9 +309,8 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
af.blake3_constraints.push_back(Blake3Constraint{
.inputs = map(arg.inputs,
[](auto& e) {
auto input_witness = get_witness_from_function_input(e);
return Blake3Input{
.witness = input_witness,
.blackbox_input = parse_input(e),
.num_bits = e.num_bits,
};
}),
Expand Down Expand Up @@ -437,12 +374,12 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
});
af.original_opcode_indices.multi_scalar_mul_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::EmbeddedCurveAdd>) {
auto input_1_x = get_witness_from_function_input(arg.input1[0]);
auto input_1_y = get_witness_from_function_input(arg.input1[1]);
auto input_1_infinite = get_witness_from_function_input(arg.input1[2]);
auto input_2_x = get_witness_from_function_input(arg.input2[0]);
auto input_2_y = get_witness_from_function_input(arg.input2[1]);
auto input_2_infinite = get_witness_from_function_input(arg.input2[2]);
auto input_1_x = parse_input(arg.input1[0]);
auto input_1_y = parse_input(arg.input1[1]);
auto input_1_infinite = parse_input(arg.input1[2]);
auto input_2_x = parse_input(arg.input2[0]);
auto input_2_y = parse_input(arg.input2[1]);
auto input_2_infinite = parse_input(arg.input2[2]);

af.ec_add_constraints.push_back(EcAdd{
.input1_x = input_1_x,
Expand Down Expand Up @@ -473,11 +410,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
af.original_opcode_indices.keccak_constraints.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::Keccakf1600>) {
af.keccak_permutations.push_back(Keccakf1600{
.state = map(arg.inputs,
[](auto& e) {
auto input_witness = get_witness_from_function_input(e);
return input_witness;
}),
.state = map(arg.inputs, [](auto& e) { return parse_input(e); }),
.result = map(arg.outputs, [](auto& e) { return e.value; }),
});
af.original_opcode_indices.keccak_permutations.push_back(opcode_index);
Expand Down Expand Up @@ -551,11 +484,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg,
af.original_opcode_indices.bigint_operations.push_back(opcode_index);
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::Poseidon2Permutation>) {
af.poseidon2_constraints.push_back(Poseidon2Constraint{
.state = map(arg.inputs,
[](auto& e) {
auto input_witness = get_witness_from_function_input(e);
return input_witness;
}),
.state = map(arg.inputs, [](auto& e) { return parse_input(e); }),
.result = map(arg.outputs, [](auto& e) { return e.value; }),
.len = arg.len,
});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "aes128_constraint.hpp"
#include "barretenberg/dsl/acir_format/acir_format.hpp"
#include "barretenberg/stdlib/encryption/aes128/aes128.hpp"
#include <cstdint>
#include <cstdio>
Expand All @@ -14,20 +15,21 @@ template <typename Builder> void create_aes128_constraints(Builder& builder, con
using field_ct = bb::stdlib::field_t<Builder>;

// Packs 16 bytes from the inputs (plaintext, iv, key) into a field element
const auto convert_input = [&](std::span<const AES128Input, std::dynamic_extent> inputs, size_t padding) {
field_ct converted = 0;
for (size_t i = 0; i < 16 - padding; ++i) {
converted *= 256;
field_ct byte = field_ct::from_witness_index(&builder, inputs[i].witness);
converted += byte;
}
for (size_t i = 0; i < padding; ++i) {
converted *= 256;
field_ct byte = padding;
converted += byte;
}
return converted;
};
const auto convert_input =
[&](std::span<const WitnessOrConstant<bb::fr>, std::dynamic_extent> inputs, size_t padding, Builder& builder) {
field_ct converted = 0;
for (size_t i = 0; i < 16 - padding; ++i) {
converted *= 256;
field_ct byte = to_field_ct(inputs[i], builder);
converted += byte;
}
for (size_t i = 0; i < padding; ++i) {
converted *= 256;
field_ct byte = padding;
converted += byte;
}
return converted;
};

// Packs 16 bytes from the outputs (witness indexes) into a field element for comparison
const auto convert_output = [&](std::span<const uint32_t, 16> outputs) {
Expand All @@ -47,11 +49,14 @@ template <typename Builder> void create_aes128_constraints(Builder& builder, con
for (size_t i = 0; i < constraint.inputs.size(); i += 16) {
field_ct to_add;
if (i + 16 > constraint.inputs.size()) {
to_add = convert_input(
std::span<const AES128Input, std::dynamic_extent>{ &constraint.inputs[i], 16 - padding_size },
padding_size);
to_add =
convert_input(std::span<const WitnessOrConstant<bb::fr>, std::dynamic_extent>{ &constraint.inputs[i],
16 - padding_size },
padding_size,
builder);
} else {
to_add = convert_input(std::span<const AES128Input, 16>{ &constraint.inputs[i], 16 }, 0);
to_add =
convert_input(std::span<const WitnessOrConstant<bb::fr>, 16>{ &constraint.inputs[i], 16 }, 0, builder);
}
converted_inputs.emplace_back(to_add);
}
Expand All @@ -63,7 +68,7 @@ template <typename Builder> void create_aes128_constraints(Builder& builder, con
}

const std::vector<field_ct> output_bytes = bb::stdlib::aes128::encrypt_buffer_cbc<Builder>(
converted_inputs, convert_input(constraint.iv, 0), convert_input(constraint.key, 0));
converted_inputs, convert_input(constraint.iv, 0, builder), convert_input(constraint.key, 0, builder));

for (size_t i = 0; i < output_bytes.size(); ++i) {
builder.assert_equal(output_bytes[i].normalize().witness_index, converted_outputs[i].normalize().witness_index);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once
#include "barretenberg/dsl/acir_format/witness_constant.hpp"
#include "barretenberg/serialize/msgpack.hpp"
#include "barretenberg/stdlib/primitives/field/field.hpp"
#include <array>
#include <cstdint>
#include <vector>
Expand All @@ -16,9 +18,9 @@ struct AES128Input {
};

struct AES128Constraint {
std::vector<AES128Input> inputs;
std::array<AES128Input, 16> iv;
std::array<AES128Input, 16> key;
std::vector<WitnessOrConstant<bb::fr>> inputs;
std::array<WitnessOrConstant<bb::fr>, 16> iv;
std::array<WitnessOrConstant<bb::fr>, 16> key;
std::vector<uint32_t> outputs;

// For serialization, update with any new fields
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ template <typename Builder> void create_blake2s_constraints(Builder& builder, co
// Get the witness assignment for each witness index
// Write the witness assignment to the byte_array
for (const auto& witness_index_num_bits : constraint.inputs) {
auto witness_index = witness_index_num_bits.witness;
auto witness_index = witness_index_num_bits.blackbox_input;
auto num_bits = witness_index_num_bits.num_bits;

// XXX: The implementation requires us to truncate the element to the nearest byte and not bit
auto num_bytes = round_to_nearest_byte(num_bits);

field_ct element = field_ct::from_witness_index(&builder, witness_index);
field_ct element = to_field_ct(witness_index, builder);
byte_array_ct element_bytes(element, num_bytes);

arr.write(element_bytes);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once
#include "barretenberg/dsl/acir_format/witness_constant.hpp"
#include "barretenberg/serialize/msgpack.hpp"
#include <array>
#include <cstdint>
Expand All @@ -7,11 +8,11 @@
namespace acir_format {

struct Blake2sInput {
uint32_t witness;
WitnessOrConstant<bb::fr> blackbox_input;
uint32_t num_bits;

// For serialization, update with any new fields
MSGPACK_FIELDS(witness, num_bits);
MSGPACK_FIELDS(blackbox_input, num_bits);
friend bool operator==(Blake2sInput const& lhs, Blake2sInput const& rhs) = default;
};

Expand Down
Loading

1 comment on commit 3247058

@AztecBot
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'C++ Benchmark'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.05.

Benchmark suite Current: 3247058 Previous: c2ccaea Ratio
nativeconstruct_proof_ultrahonk_power_of_2/20 5030.665084999995 ms/iter 4780.158285000013 ms/iter 1.05

This comment was automatically generated by workflow using github-action-benchmark.

CC: @ludamad @codygunton

Please sign in to comment.