Skip to content

Commit

Permalink
feat!: add is_infinite to curve addition opcode (#6384)
Browse files Browse the repository at this point in the history
Resolves noir-lang/noir#4978

Since elliptic curve addition in barretenberg is already handling the
point at infinity, I simply expose it in the ACIR opcode.
  • Loading branch information
guipublic authored May 17, 2024
1 parent 26f2197 commit 75d81c5
Show file tree
Hide file tree
Showing 30 changed files with 494 additions and 211 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -342,15 +342,19 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, Aci
.scalars = map(arg.scalars, [](auto& e) { return e.witness.value; }),
.out_point_x = arg.outputs[0].value,
.out_point_y = arg.outputs[1].value,
.out_point_is_infinite = arg.outputs[2].value,
});
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::EmbeddedCurveAdd>) {
af.ec_add_constraints.push_back(EcAdd{
.input1_x = arg.input1_x.witness.value,
.input1_y = arg.input1_y.witness.value,
.input2_x = arg.input2_x.witness.value,
.input2_y = arg.input2_y.witness.value,
.input1_x = arg.input1[0].witness.value,
.input1_y = arg.input1[1].witness.value,
.input1_infinite = arg.input1[2].witness.value,
.input2_x = arg.input2[0].witness.value,
.input2_y = arg.input2[1].witness.value,
.input2_infinite = arg.input2[2].witness.value,
.result_x = arg.outputs[0].value,
.result_y = arg.outputs[1].value,
.result_infinite = arg.outputs[2].value,
});
} else if constexpr (std::is_same_v<T, Program::BlackBoxFuncCall::Keccak256>) {
af.keccak_constraints.push_back(KeccakConstraint{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,35 @@ void create_ec_add_constraint(Builder& builder, const EcAdd& input, bool has_val
// Input to cycle_group points
using cycle_group_ct = bb::stdlib::cycle_group<Builder>;
using field_ct = bb::stdlib::field_t<Builder>;
using bool_ct = bb::stdlib::bool_t<Builder>;

auto x1 = field_ct::from_witness_index(&builder, input.input1_x);
auto y1 = field_ct::from_witness_index(&builder, input.input1_y);
auto x2 = field_ct::from_witness_index(&builder, input.input2_x);
auto y2 = field_ct::from_witness_index(&builder, input.input2_y);
auto infinite1 = bool_ct(field_ct::from_witness_index(&builder, input.input1_infinite));
auto infinite2 = bool_ct(field_ct::from_witness_index(&builder, input.input2_infinite));
if (!has_valid_witness_assignments) {
auto g1 = grumpkin::g1::affine_one;
// We need to have correct values representing points on the curve
builder.variables[input.input1_x] = g1.x;
builder.variables[input.input1_y] = g1.y;
builder.variables[input.input1_infinite] = fr(0);
builder.variables[input.input2_x] = g1.x;
builder.variables[input.input2_y] = g1.y;
builder.variables[input.input2_infinite] = fr(0);
}

cycle_group_ct input1_point(x1, y1, false);
cycle_group_ct input2_point(x2, y2, false);

cycle_group_ct input1_point(x1, y1, infinite1);
cycle_group_ct input2_point(x2, y2, infinite2);
// Addition
cycle_group_ct result = input1_point + input2_point;

auto x_normalized = result.x.normalize();
auto y_normalized = result.y.normalize();
auto infinite = result.is_point_at_infinity().normalize();
builder.assert_equal(x_normalized.witness_index, input.result_x);
builder.assert_equal(y_normalized.witness_index, input.result_y);
builder.assert_equal(infinite.witness_index, input.result_infinite);
}

template void create_ec_add_constraint<UltraCircuitBuilder>(UltraCircuitBuilder& builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ namespace acir_format {
struct EcAdd {
uint32_t input1_x;
uint32_t input1_y;
uint32_t input1_infinite;
uint32_t input2_x;
uint32_t input2_y;
uint32_t input2_infinite;
uint32_t result_x;
uint32_t result_y;
uint32_t result_infinite;

// for serialization, update with any new fields
MSGPACK_FIELDS(input1_x, input1_y, input2_x, input2_y, result_x, result_y);
MSGPACK_FIELDS(
input1_x, input1_y, input1_infinite, input2_x, input2_y, input2_infinite, result_x, result_y, result_infinite);
friend bool operator==(EcAdd const& lhs, EcAdd const& rhs) = default;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,18 @@ size_t generate_ec_add_constraint(EcAdd& ec_add_constraint, WitnessVector& witne
witness_values.push_back(g1.y);
witness_values.push_back(result.x.get_value());
witness_values.push_back(result.y.get_value());
witness_values.push_back(fr(0));
witness_values.push_back(fr(0));
ec_add_constraint = EcAdd{
.input1_x = 1,
.input1_y = 2,
.input1_infinite = 7,
.input2_x = 3,
.input2_y = 4,
.input2_infinite = 7,
.result_x = 5,
.result_y = 6,
.result_infinite = 8,
};
return witness_values.size();
}
Expand Down Expand Up @@ -85,6 +90,92 @@ TEST_F(EcOperations, TestECOperations)
auto prover = composer.create_prover(builder);

auto proof = prover.construct_proof();

EXPECT_TRUE(CircuitChecker::check(builder));
auto verifier = composer.create_verifier(builder);
EXPECT_EQ(verifier.verify_proof(proof), true);
}

TEST_F(EcOperations, TestECMultiScalarMul)
{
MultiScalarMul msm_constrain;

WitnessVector witness_values;
witness_values.emplace_back(fr(0));

witness_values = {
// dummy
fr(0),
// g1: x,y,infinite
fr(1),
fr("0x0000000000000002cf135e7506a45d632d270d45f1181294833fc48d823f272c"),
fr(0),
// low, high scalars
fr(1),
fr(0),
// result
fr("0x06ce1b0827aafa85ddeb49cdaa36306d19a74caa311e13d46d8bc688cdbffffe"),
fr("0x1c122f81a3a14964909ede0ba2a6855fc93faf6fa1a788bf467be7e7a43f80ac"),
fr(0),
};
msm_constrain = MultiScalarMul{
.points = { 1, 2, 3, 1, 2, 3 },
.scalars = { 4, 5, 4, 5 },
.out_point_x = 6,
.out_point_y = 7,
.out_point_is_infinite = 0,
};
auto res_x = fr("0x06ce1b0827aafa85ddeb49cdaa36306d19a74caa311e13d46d8bc688cdbffffe");
auto assert_equal = poly_triple{
.a = 6,
.b = 0,
.c = 0,
.q_m = 0,
.q_l = fr::neg_one(),
.q_r = 0,
.q_o = 0,
.q_c = res_x,
};

size_t num_variables = witness_values.size();
AcirFormat constraint_system{
.varnum = static_cast<uint32_t>(num_variables + 1),
.recursive = false,
.num_acir_opcodes = 1,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
.aes128_constraints = {},
.sha256_constraints = {},
.sha256_compression = {},
.schnorr_constraints = {},
.ecdsa_k1_constraints = {},
.ecdsa_r1_constraints = {},
.blake2s_constraints = {},
.blake3_constraints = {},
.keccak_constraints = {},
.keccak_permutations = {},
.pedersen_constraints = {},
.pedersen_hash_constraints = {},
.poseidon2_constraints = {},
.multi_scalar_mul_constraints = { msm_constrain },
.ec_add_constraints = {},
.recursion_constraints = {},
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.poly_triple_constraints = { assert_equal },
.quad_constraints = {},
.block_constraints = {},
};

auto builder = create_circuit(constraint_system, /*size_hint*/ 0, witness_values);

auto composer = Composer();
auto prover = composer.create_prover(builder);

auto proof = prover.construct_proof();

EXPECT_TRUE(CircuitChecker::check(builder));
auto verifier = composer.create_verifier(builder);
EXPECT_EQ(verifier.verify_proof(proof), true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@ template <typename Builder> void create_multi_scalar_mul_constraint(Builder& bui
using cycle_group_ct = bb::stdlib::cycle_group<Builder>;
using cycle_scalar_ct = typename bb::stdlib::cycle_group<Builder>::cycle_scalar;
using field_ct = bb::stdlib::field_t<Builder>;
using bool_ct = bb::stdlib::bool_t<Builder>;

std::vector<cycle_group_ct> points;
std::vector<cycle_scalar_ct> scalars;

for (size_t i = 0; i < input.points.size(); i += 2) {
for (size_t i = 0; i < input.points.size(); i += 3) {
// Instantiate the input point/variable base as `cycle_group_ct`
auto point_x = field_ct::from_witness_index(&builder, input.points[i]);
auto point_y = field_ct::from_witness_index(&builder, input.points[i + 1]);
cycle_group_ct input_point(point_x, point_y, false);

auto infinite = bool_ct(field_ct::from_witness_index(&builder, input.points[i + 2]));
cycle_group_ct input_point(point_x, point_y, infinite);
// Reconstruct the scalar from the low and high limbs
field_ct scalar_low_as_field = field_ct::from_witness_index(&builder, input.scalars[i]);
field_ct scalar_high_as_field = field_ct::from_witness_index(&builder, input.scalars[i + 1]);
field_ct scalar_low_as_field = field_ct::from_witness_index(&builder, input.scalars[2 * (i / 3)]);
field_ct scalar_high_as_field = field_ct::from_witness_index(&builder, input.scalars[2 * (i / 3) + 1]);
cycle_scalar_ct scalar(scalar_low_as_field, scalar_high_as_field);

// Add the point and scalar to the vectors
Expand All @@ -38,6 +39,7 @@ template <typename Builder> void create_multi_scalar_mul_constraint(Builder& bui
// Add the constraints
builder.assert_equal(output_point.x.get_witness_index(), input.out_point_x);
builder.assert_equal(output_point.y.get_witness_index(), input.out_point_y);
builder.assert_equal(output_point.is_point_at_infinity().witness_index, input.out_point_is_infinite);
}

template void create_multi_scalar_mul_constraint<UltraCircuitBuilder>(UltraCircuitBuilder& builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ struct MultiScalarMul {
std::vector<uint32_t> scalars;
uint32_t out_point_x;
uint32_t out_point_y;
uint32_t out_point_is_infinite;

// for serialization, update with any new fields
MSGPACK_FIELDS(points, scalars, out_point_x, out_point_y);
MSGPACK_FIELDS(points, scalars, out_point_x, out_point_y, out_point_is_infinite);
friend bool operator==(MultiScalarMul const& lhs, MultiScalarMul const& rhs) = default;
};

Expand Down
44 changes: 22 additions & 22 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,19 +149,17 @@ struct BlackBoxFuncCall {
struct MultiScalarMul {
std::vector<Program::FunctionInput> points;
std::vector<Program::FunctionInput> scalars;
std::array<Program::Witness, 2> outputs;
std::array<Program::Witness, 3> outputs;

friend bool operator==(const MultiScalarMul&, const MultiScalarMul&);
std::vector<uint8_t> bincodeSerialize() const;
static MultiScalarMul bincodeDeserialize(std::vector<uint8_t>);
};

struct EmbeddedCurveAdd {
Program::FunctionInput input1_x;
Program::FunctionInput input1_y;
Program::FunctionInput input2_x;
Program::FunctionInput input2_y;
std::array<Program::Witness, 2> outputs;
std::array<Program::FunctionInput, 3> input1;
std::array<Program::FunctionInput, 3> input2;
std::array<Program::Witness, 3> outputs;

friend bool operator==(const EmbeddedCurveAdd&, const EmbeddedCurveAdd&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -807,8 +805,10 @@ struct BlackBoxOp {
struct EmbeddedCurveAdd {
Program::MemoryAddress input1_x;
Program::MemoryAddress input1_y;
Program::MemoryAddress input1_infinite;
Program::MemoryAddress input2_x;
Program::MemoryAddress input2_y;
Program::MemoryAddress input2_infinite;
Program::HeapArray result;

friend bool operator==(const EmbeddedCurveAdd&, const EmbeddedCurveAdd&);
Expand Down Expand Up @@ -3194,16 +3194,10 @@ namespace Program {

inline bool operator==(const BlackBoxFuncCall::EmbeddedCurveAdd& lhs, const BlackBoxFuncCall::EmbeddedCurveAdd& rhs)
{
if (!(lhs.input1_x == rhs.input1_x)) {
if (!(lhs.input1 == rhs.input1)) {
return false;
}
if (!(lhs.input1_y == rhs.input1_y)) {
return false;
}
if (!(lhs.input2_x == rhs.input2_x)) {
return false;
}
if (!(lhs.input2_y == rhs.input2_y)) {
if (!(lhs.input2 == rhs.input2)) {
return false;
}
if (!(lhs.outputs == rhs.outputs)) {
Expand Down Expand Up @@ -3237,10 +3231,8 @@ template <typename Serializer>
void serde::Serializable<Program::BlackBoxFuncCall::EmbeddedCurveAdd>::serialize(
const Program::BlackBoxFuncCall::EmbeddedCurveAdd& obj, Serializer& serializer)
{
serde::Serializable<decltype(obj.input1_x)>::serialize(obj.input1_x, serializer);
serde::Serializable<decltype(obj.input1_y)>::serialize(obj.input1_y, serializer);
serde::Serializable<decltype(obj.input2_x)>::serialize(obj.input2_x, serializer);
serde::Serializable<decltype(obj.input2_y)>::serialize(obj.input2_y, serializer);
serde::Serializable<decltype(obj.input1)>::serialize(obj.input1, serializer);
serde::Serializable<decltype(obj.input2)>::serialize(obj.input2, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
}

Expand All @@ -3250,10 +3242,8 @@ Program::BlackBoxFuncCall::EmbeddedCurveAdd serde::Deserializable<
Program::BlackBoxFuncCall::EmbeddedCurveAdd>::deserialize(Deserializer& deserializer)
{
Program::BlackBoxFuncCall::EmbeddedCurveAdd obj;
obj.input1_x = serde::Deserializable<decltype(obj.input1_x)>::deserialize(deserializer);
obj.input1_y = serde::Deserializable<decltype(obj.input1_y)>::deserialize(deserializer);
obj.input2_x = serde::Deserializable<decltype(obj.input2_x)>::deserialize(deserializer);
obj.input2_y = serde::Deserializable<decltype(obj.input2_y)>::deserialize(deserializer);
obj.input1 = serde::Deserializable<decltype(obj.input1)>::deserialize(deserializer);
obj.input2 = serde::Deserializable<decltype(obj.input2)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
return obj;
}
Expand Down Expand Up @@ -4638,12 +4628,18 @@ inline bool operator==(const BlackBoxOp::EmbeddedCurveAdd& lhs, const BlackBoxOp
if (!(lhs.input1_y == rhs.input1_y)) {
return false;
}
if (!(lhs.input1_infinite == rhs.input1_infinite)) {
return false;
}
if (!(lhs.input2_x == rhs.input2_x)) {
return false;
}
if (!(lhs.input2_y == rhs.input2_y)) {
return false;
}
if (!(lhs.input2_infinite == rhs.input2_infinite)) {
return false;
}
if (!(lhs.result == rhs.result)) {
return false;
}
Expand Down Expand Up @@ -4676,8 +4672,10 @@ void serde::Serializable<Program::BlackBoxOp::EmbeddedCurveAdd>::serialize(
{
serde::Serializable<decltype(obj.input1_x)>::serialize(obj.input1_x, serializer);
serde::Serializable<decltype(obj.input1_y)>::serialize(obj.input1_y, serializer);
serde::Serializable<decltype(obj.input1_infinite)>::serialize(obj.input1_infinite, serializer);
serde::Serializable<decltype(obj.input2_x)>::serialize(obj.input2_x, serializer);
serde::Serializable<decltype(obj.input2_y)>::serialize(obj.input2_y, serializer);
serde::Serializable<decltype(obj.input2_infinite)>::serialize(obj.input2_infinite, serializer);
serde::Serializable<decltype(obj.result)>::serialize(obj.result, serializer);
}

Expand All @@ -4689,8 +4687,10 @@ Program::BlackBoxOp::EmbeddedCurveAdd serde::Deserializable<Program::BlackBoxOp:
Program::BlackBoxOp::EmbeddedCurveAdd obj;
obj.input1_x = serde::Deserializable<decltype(obj.input1_x)>::deserialize(deserializer);
obj.input1_y = serde::Deserializable<decltype(obj.input1_y)>::deserialize(deserializer);
obj.input1_infinite = serde::Deserializable<decltype(obj.input1_infinite)>::deserialize(deserializer);
obj.input2_x = serde::Deserializable<decltype(obj.input2_x)>::deserialize(deserializer);
obj.input2_y = serde::Deserializable<decltype(obj.input2_y)>::deserialize(deserializer);
obj.input2_infinite = serde::Deserializable<decltype(obj.input2_infinite)>::deserialize(deserializer);
obj.result = serde::Deserializable<decltype(obj.result)>::deserialize(deserializer);
return obj;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ use dep::protocol_types::{
constants::GENERATOR_INDEX__SYMMETRIC_KEY, grumpkin_private_key::GrumpkinPrivateKey,
grumpkin_point::GrumpkinPoint, utils::arr_copy_slice
};
use dep::std::{hash::sha256, embedded_curve_ops::multi_scalar_mul};
use dep::std::{hash::sha256, embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_scalar_mul}};

// TODO(#5726): This function is called deriveAESSecret in TS. I don't like point_to_symmetric_key name much since
// point is not the only input of the function. Unify naming with TS once we have a better name.
pub fn point_to_symmetric_key(secret: GrumpkinPrivateKey, point: GrumpkinPoint) -> [u8; 32] {
let shared_secret_fields = multi_scalar_mul([point.x, point.y], [secret.low, secret.high]);
let shared_secret_fields = multi_scalar_mul(
[EmbeddedCurvePoint { x: point.x, y: point.y, is_infinite: false }],
[EmbeddedCurveScalar { lo: secret.low, hi: secret.high }]
);
// TODO(https://github.com/AztecProtocol/aztec-packages/issues/6061): make the func return Point struct directly
let shared_secret = GrumpkinPoint::new(shared_secret_fields[0], shared_secret_fields[1]);
let mut shared_secret_bytes_with_separator = [0 as u8; 65];
Expand Down
Loading

0 comments on commit 75d81c5

Please sign in to comment.