Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: aggregate with short scalars in UH Recursion #11478

Merged
merged 18 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ UltraRecursiveVerifier_<Flavor>::Output UltraRecursiveVerifier_<Flavor>::verify_
// TODO(https://github.com/AztecProtocol/barretenberg/issues/995): generate this challenge properly.
typename Curve::ScalarField recursion_separator =
Curve::ScalarField::from_witness_index(builder, builder->add_variable(42));
agg_obj.aggregate(nested_agg_obj, recursion_separator);
agg_obj.template aggregate<Builder>(nested_agg_obj, recursion_separator);

// Execute Sumcheck Verifier and extract multivariate opening point u = (u_0, ..., u_{d-1}) and purported
// multivariate evaluations at u
Expand Down Expand Up @@ -143,11 +143,11 @@ UltraRecursiveVerifier_<Flavor>::Output UltraRecursiveVerifier_<Flavor>::verify_
pairing_points[0] = pairing_points[0].normalize();
pairing_points[1] = pairing_points[1].normalize();
// TODO(https://github.com/AztecProtocol/barretenberg/issues/995): generate recursion separator challenge properly.
agg_obj.aggregate(pairing_points, recursion_separator);
agg_obj.template aggregate<Builder>(pairing_points, recursion_separator);
output.agg_obj = std::move(agg_obj);

// Extract the IPA claim from the public inputs
// Parse out the nested IPA claim using key->ipa_claim_public_input_indices and runs the native IPA verifier.
// Parse out the nested IPA claim using key->ipa_claim_public_input_indices and run the native IPA verifier.
if constexpr (HasIPAAccumulator<Flavor>) {
const auto recover_fq_from_public_inputs = [](std::array<FF, Curve::BaseField::NUM_LIMBS>& limbs) {
for (size_t k = 0; k < Curve::BaseField::NUM_LIMBS; k++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,36 @@ template <typename Curve> struct aggregation_state {
{
return P0 == other.P0 && P1 == other.P1;
};

template <typename BuilderType = void>
void aggregate(aggregation_state const& other, typename Curve::ScalarField recursion_separator)
{
P0 += other.P0 * recursion_separator;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this costs >140K per aggregate call

P1 += other.P1 * recursion_separator;
if constexpr (std::is_same_v<BuilderType, MegaCircuitBuilder>) {
P0 += other.P0 * recursion_separator;
P1 += other.P1 * recursion_separator;
} else {
// Save gates using short scalars. We don't apply `bn254_endo_batch_mul` to the vector {1,
// recursion_separator} directly to avoid edge cases.
typename Curve::Group point_to_aggregate = other.P0.scalar_mul(recursion_separator, 128);
P0 += point_to_aggregate;
point_to_aggregate = other.P1.scalar_mul(recursion_separator, 128);
P1 += point_to_aggregate;
}
}

template <typename BuilderType = void>
void aggregate(std::array<typename Curve::Group, 2> const& other, typename Curve::ScalarField recursion_separator)
{
P0 += other[0] * recursion_separator;
P1 += other[1] * recursion_separator;
if constexpr (std::is_same_v<BuilderType, MegaCircuitBuilder>) {
P0 += other[0] * recursion_separator;
P1 += other[1] * recursion_separator;
} else {
// Save gates using short scalars. We don't apply `bn254_endo_batch_mul` to the vector {1,
// recursion_separator} directly to avoid edge cases.
typename Curve::Group point_to_aggregate = other[0].scalar_mul(recursion_separator, 128);
P0 += point_to_aggregate;
point_to_aggregate = other[1].scalar_mul(recursion_separator, 128);
P1 += point_to_aggregate;
}
}

PairingPointAccumulatorIndices get_witness_indices()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
result.y.assert_is_in_field();
return result;
}
element scalar_mul(const Fr& scalar, const size_t max_num_bits = 0) const;

element reduce() const
{
Expand Down Expand Up @@ -525,7 +526,10 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
num_fives = num_points / 5;
num_sixes = 0;
// size-6 table is expensive and only benefits us if creating them reduces the number of total tables
if (num_fives * 5 == (num_points - 1)) {
if (num_points == 1) {
Copy link
Contributor Author

@iakovenkos iakovenkos Feb 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handle the case of a single point to re-use this function in the * operator and in scalar_mul method

Probably there are other edge cases

num_fives = 0;
num_sixes = 0;
} else if (num_fives * 5 == (num_points - 1)) {
num_fives -= 1;
num_sixes = 1;
} else if (num_fives * 5 == (num_points - 2) && num_fives >= 2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,142 @@ template <typename TestType> class stdlib_biggroup : public testing::Test {
EXPECT_CIRCUIT_CORRECTNESS(builder);
}

// Test short scalar mul with variable even bit length. For efficiency, it's split into two tests.
static void test_short_scalar_mul_2_126()
{
Builder builder;
const size_t max_num_bits = 128;

// We only test even bit lengths, because `bn254_endo_batch_mul` used in 'scalar_mul' can't handle odd lengths.
for (size_t i = 2; i < max_num_bits; i += 2) {
affine_element input(element::random_element());
// Get a random 256 integer
uint256_t scalar_raw = engine.get_random_uint256();
// Produce a length =< i scalar.
scalar_raw = scalar_raw >> (256 - i);
fr scalar = fr(scalar_raw);

// Avoid multiplication by 0 that may occur when `i` is small
if (scalar == fr(0)) {
scalar += 1;
};

element_ct P = element_ct::from_witness(&builder, input);
scalar_ct x = scalar_ct::from_witness(&builder, scalar);

// Set input tags
x.set_origin_tag(challenge_origin_tag);
P.set_origin_tag(submitted_value_origin_tag);

std::cerr << "gates before mul " << builder.get_estimated_num_finalized_gates() << std::endl;
// Multiply using specified scalar length
element_ct c = P.scalar_mul(x, i);
std::cerr << "builder aftr mul " << builder.get_estimated_num_finalized_gates() << std::endl;
affine_element c_expected(element(input) * scalar);

// Check the result of the multiplication has a tag that's the union of inputs' tags
EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag);
fq c_x_result(c.x.get_value().lo);
fq c_y_result(c.y.get_value().lo);

EXPECT_EQ(c_x_result, c_expected.x);

EXPECT_EQ(c_y_result, c_expected.y);
}

EXPECT_CIRCUIT_CORRECTNESS(builder);
}

static void test_short_scalar_mul_128_252()
{
Builder builder;
const size_t max_num_bits = 254;

// We only test even bit lengths, because `bn254_endo_batch_mul` used in 'scalar_mul' can't handle odd lengths.
for (size_t i = 128; i < max_num_bits; i += 2) {
affine_element input(element::random_element());
// Get a random 256-bit integer
uint256_t scalar_raw = engine.get_random_uint256();
// Produce a length =< i scalar.
scalar_raw = scalar_raw >> (256 - i);
fr scalar = fr(scalar_raw);

element_ct P = element_ct::from_witness(&builder, input);
scalar_ct x = scalar_ct::from_witness(&builder, scalar);

// Set input tags
x.set_origin_tag(challenge_origin_tag);
P.set_origin_tag(submitted_value_origin_tag);

std::cerr << "gates before mul " << builder.get_estimated_num_finalized_gates() << std::endl;
// Multiply using specified scalar length
element_ct c = P.scalar_mul(x, i);
std::cerr << "builder aftr mul " << builder.get_estimated_num_finalized_gates() << std::endl;
affine_element c_expected(element(input) * scalar);

// Check the result of the multiplication has a tag that's the union of inputs' tags
EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag);
fq c_x_result(c.x.get_value().lo);
fq c_y_result(c.y.get_value().lo);

EXPECT_EQ(c_x_result, c_expected.x);

EXPECT_EQ(c_y_result, c_expected.y);
}

EXPECT_CIRCUIT_CORRECTNESS(builder);
}

static void test_short_scalar_mul_infinity()
{
// We check that a point at infinity preserves `is_point_at_infinity()` flag after being multiplied against a
// short scalar and also check that the number of gates in this case is equal to the number of gates spent on a
// finite point.

// Populate test points.
std::vector<element> points(2);

points[0] = element::infinity();
points[1] = element::random_element();
// Containter for gate counts.
std::vector<size_t> gates(2);

// We initialize this flag as `true`, because the first result is expected to be the point at infinity.
bool expect_infinity = true;

for (auto [point, num_gates] : zip_view(points, gates)) {
Builder builder;

const size_t max_num_bits = 128;
// Get a random 256-bit integer
uint256_t scalar_raw = engine.get_random_uint256();
// Produce a length =< max_num_bits scalar.
scalar_raw = scalar_raw >> (256 - max_num_bits);
fr scalar = fr(scalar_raw);

element_ct P = element_ct::from_witness(&builder, point);
scalar_ct x = scalar_ct::from_witness(&builder, scalar);

// Set input tags
x.set_origin_tag(challenge_origin_tag);
P.set_origin_tag(submitted_value_origin_tag);

std::cerr << "gates before mul " << builder.get_estimated_num_finalized_gates() << std::endl;
element_ct c = P.scalar_mul(x, max_num_bits);
std::cerr << "builder aftr mul " << builder.get_estimated_num_finalized_gates() << std::endl;
num_gates = builder.get_estimated_num_finalized_gates();
// Check the result of the multiplication has a tag that's the union of inputs' tags
EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag);

EXPECT_EQ(c.is_point_at_infinity().get_value(), expect_infinity);
EXPECT_CIRCUIT_CORRECTNESS(builder);
// The second point is finite, hence we flip the flag
expect_infinity = false;
}
// Check that the numbers of gates are equal in both cases.
EXPECT_EQ(gates[0], gates[1]);
}

static void test_twin_mul()
{
Builder builder;
Expand Down Expand Up @@ -950,26 +1086,39 @@ template <typename TestType> class stdlib_biggroup : public testing::Test {
static void test_compute_naf()
{
Builder builder = Builder();
size_t num_repetitions(32);
for (size_t i = 0; i < num_repetitions; i++) {
fr scalar_val = fr::random_element();
size_t max_num_bits = 254;
// Our design of NAF and the way it is used assumes the even length of scalars.
for (size_t length = 2; length < max_num_bits; length += 2) {

fr scalar_val;

uint256_t scalar_raw = engine.get_random_uint256();
scalar_raw = scalar_raw >> (256 - length);

scalar_val = fr(scalar_raw);

// NAF with short scalars doesn't handle 0
if (scalar_val == fr(0)) {
scalar_val += 1;
};
scalar_ct scalar = scalar_ct::from_witness(&builder, scalar_val);
// Set tag for scalar
scalar.set_origin_tag(submitted_value_origin_tag);
auto naf = element_ct::compute_naf(scalar);
auto naf = element_ct::compute_naf(scalar, length);

for (const auto& bit : naf) {
// Check that the tag is propagated to bits
EXPECT_EQ(bit.get_origin_tag(), submitted_value_origin_tag);
}
// scalar = -naf[254] + \sum_{i=0}^{253}(1-2*naf[i]) 2^{253-i}
fr reconstructed_val(0);
for (size_t i = 0; i < 254; i++) {
reconstructed_val += (fr(1) - fr(2) * fr(naf[i].witness_bool)) * fr(uint256_t(1) << (253 - i));
for (size_t i = 0; i < length; i++) {
reconstructed_val += (fr(1) - fr(2) * fr(naf[i].witness_bool)) * fr(uint256_t(1) << (length - 1 - i));
};
reconstructed_val -= fr(naf[254].witness_bool);
reconstructed_val -= fr(naf[length].witness_bool);
EXPECT_EQ(scalar_val, reconstructed_val);
}

EXPECT_CIRCUIT_CORRECTNESS(builder);
}

Expand Down Expand Up @@ -1614,6 +1763,33 @@ HEAVY_TYPED_TEST(stdlib_biggroup, mul)
{
TestFixture::test_mul();
}

HEAVY_TYPED_TEST(stdlib_biggroup, short_scalar_mul_2_126_bits)
{
if constexpr (HasGoblinBuilder<TypeParam>) {
GTEST_SKIP();
} else {
TestFixture::test_short_scalar_mul_2_126();
}
}
HEAVY_TYPED_TEST(stdlib_biggroup, short_scalar_mul_128_252_bits)
{
if constexpr (HasGoblinBuilder<TypeParam>) {
GTEST_SKIP();
} else {
TestFixture::test_short_scalar_mul_128_252();
}
}

HEAVY_TYPED_TEST(stdlib_biggroup, short_scalar_mul_infinity)
{
if constexpr (HasGoblinBuilder<TypeParam>) {
GTEST_SKIP();
} else {
TestFixture::test_short_scalar_mul_infinity();
}
}

HEAVY_TYPED_TEST(stdlib_biggroup, twin_mul)
{
if constexpr (HasGoblinBuilder<TypeParam>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ element<C, Fq, Fr, G> element<C, Fq, Fr, G>::bn254_endo_batch_mul(const std::vec
const std::vector<Fr>& small_scalars,
const size_t max_num_small_bits)
{
ASSERT(max_num_small_bits >= 128);

ASSERT(max_num_small_bits % 2 == 0);

const size_t num_big_points = big_points.size();
const size_t num_small_points = small_points.size();
C* ctx = nullptr;
Expand Down
Loading