chore: aggregate with short scalars in UH Recursion #11478

Merged 18 commits on Feb 6, 2025
@@ -100,7 +100,7 @@ UltraRecursiveVerifier_<Flavor>::Output UltraRecursiveVerifier_<Flavor>::verify_
// TODO(https://github.com/AztecProtocol/barretenberg/issues/995): generate this challenge properly.
typename Curve::ScalarField recursion_separator =
Curve::ScalarField::from_witness_index(builder, builder->add_variable(42));
agg_obj.aggregate(nested_agg_obj, recursion_separator);
agg_obj.template aggregate<Builder>(nested_agg_obj, recursion_separator);

// Execute Sumcheck Verifier and extract multivariate opening point u = (u_0, ..., u_{d-1}) and purported
// multivariate evaluations at u
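A side note on the `.template` call syntax introduced above (a general C++ rule, not something specific to this PR): `aggregate` is now a member template invoked through an object whose type depends on a template parameter, so the call site needs the `template` disambiguator; otherwise the `<` would parse as a comparison. A minimal standalone sketch with hypothetical names:

    template <typename T> struct Accumulator {
        template <typename Builder> void aggregate(const T& other, int separator) { /* fold `other` into the accumulator */ }
    };

    template <typename Builder, typename T> void fold(Accumulator<T>& acc, const T& other, int separator)
    {
        // `acc` has a dependent type, so the member-template call must be disambiguated:
        acc.template aggregate<Builder>(other, separator);
    }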
@@ -143,11 +143,11 @@ UltraRecursiveVerifier_<Flavor>::Output UltraRecursiveVerifier_<Flavor>::verify_
pairing_points[0] = pairing_points[0].normalize();
pairing_points[1] = pairing_points[1].normalize();
// TODO(https://github.com/AztecProtocol/barretenberg/issues/995): generate recursion separator challenge properly.
agg_obj.aggregate(pairing_points, recursion_separator);
agg_obj.template aggregate<Builder>(pairing_points, recursion_separator);
output.agg_obj = std::move(agg_obj);

// Extract the IPA claim from the public inputs
// Parse out the nested IPA claim using key->ipa_claim_public_input_indices and runs the native IPA verifier.
// Parse out the nested IPA claim using key->ipa_claim_public_input_indices and run the native IPA verifier.
if constexpr (HasIPAAccumulator<Flavor>) {
const auto recover_fq_from_public_inputs = [](std::array<FF, Curve::BaseField::NUM_LIMBS>& limbs) {
for (size_t k = 0; k < Curve::BaseField::NUM_LIMBS; k++) {
@@ -20,17 +20,36 @@ template <typename Curve> struct aggregation_state {
{
return P0 == other.P0 && P1 == other.P1;
};

template <typename BuilderType = void>
void aggregate(aggregation_state const& other, typename Curve::ScalarField recursion_separator)
{
P0 += other.P0 * recursion_separator;
[Contributor Author]: this costs >140K per aggregate call
P1 += other.P1 * recursion_separator;
if constexpr (std::is_same_v<BuilderType, MegaCircuitBuilder>) {
P0 += other.P0 * recursion_separator;
P1 += other.P1 * recursion_separator;
} else {
// Save gates using short scalars. We don't apply `bn254_endo_batch_mul` to the vector {1,
// recursion_separator} directly to avoid edge cases.
typename Curve::Group point_to_aggregate = other.P0.template scalar_mul<128>(recursion_separator);
[Contributor Author]: this costs ~72K per aggregate call
P0 += point_to_aggregate;
point_to_aggregate = other.P1.template scalar_mul<128>(recursion_separator);
P1 += point_to_aggregate;
}
}

template <typename BuilderType = void>
void aggregate(std::array<typename Curve::Group, 2> const& other, typename Curve::ScalarField recursion_separator)
{
P0 += other[0] * recursion_separator;
P1 += other[1] * recursion_separator;
if constexpr (std::is_same_v<BuilderType, MegaCircuitBuilder>) {
P0 += other[0] * recursion_separator;
P1 += other[1] * recursion_separator;
} else {
// Save gates using short scalars. We don't apply `bn254_endo_batch_mul` to the vector {1,
// recursion_separator} directly to avoid edge cases.
typename Curve::Group point_to_aggregate = other[0].template scalar_mul<128>(recursion_separator);
P0 += point_to_aggregate;
point_to_aggregate = other[1].template scalar_mul<128>(recursion_separator);
P1 += point_to_aggregate;
}
}

PairingPointAccumulatorIndices get_witness_indices()
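Rough arithmetic behind the gate counts quoted in the review comments above (a back-of-the-envelope estimate, not a measurement from this PR): the ladder cost of a non-native scalar multiplication grows roughly linearly with the number of NAF rounds, so shrinking the separator from the full ~254-bit width to 128 bits should roughly halve each of the two multiplications per call:

    expected ratio ~ 128 / 254 ~ 0.50        observed ratio ~ 72K / 140K ~ 0.51 gates per aggregate call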
@@ -143,6 +143,7 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
result.y.assert_is_in_field();
return result;
}
template <size_t max_num_bits> element scalar_mul(const Fr& scalar) const;

element reduce() const
{
@@ -525,7 +526,10 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
num_fives = num_points / 5;
num_sixes = 0;
// size-6 table is expensive and only benefits us if creating them reduces the number of total tables
if (num_fives * 5 == (num_points - 1)) {
if (num_points == 1) {
[Contributor Author — @iakovenkos, Feb 1, 2025]: Handle the case of a single point to re-use this function in the * operator and in the scalar_mul method. There are probably other edge cases.

num_fives = 0;
num_sixes = 0;
} else if (num_fives * 5 == (num_points - 1)) {
num_fives -= 1;
num_sixes = 1;
} else if (num_fives * 5 == (num_points - 2) && num_fives >= 2) {
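A quick worked check of the table-size selection above (toy point counts, just tracing the branches shown in this hunk; the rest of the batch-mul setup is truncated in the diff):

    num_points = 1  ->  num_fives = 0, num_sixes = 0                        (new single-point branch)
    num_points = 6  ->  num_fives = 6 / 5 = 1;  1 * 5 == 6 - 1,  so num_fives = 0, num_sixes = 1
    num_points = 11 ->  num_fives = 11 / 5 = 2; 2 * 5 == 11 - 1, so num_fives = 1, num_sixes = 1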
@@ -414,6 +414,74 @@ template <typename TestType> class stdlib_biggroup : public testing::Test {
EXPECT_CIRCUIT_CORRECTNESS(builder);
}

static void test_short_scalar_mul()
{
Builder builder;
size_t num_repetitions = 1;
for (size_t i = 0; i < num_repetitions; ++i) {
affine_element input(element::random_element());
// Get 128-bit scalar
uint256_t scalar_raw = fr::random_element();
scalar_raw.data[2] = 0ULL;
scalar_raw.data[3] = 0ULL;
fr scalar = fr(scalar_raw);
// Add skew
if (uint256_t(scalar).get_bit(0)) {
scalar -= fr(1);
}
element_ct P = element_ct::from_witness(&builder, input);
scalar_ct x = scalar_ct::from_witness(&builder, scalar);

// Set input tags
x.set_origin_tag(challenge_origin_tag);
P.set_origin_tag(submitted_value_origin_tag);

std::cerr << "gates before mul " << builder.get_estimated_num_finalized_gates() << std::endl;
element_ct c = P.template scalar_mul<128>(x);
std::cerr << "gates after mul " << builder.get_estimated_num_finalized_gates() << std::endl;
affine_element c_expected(element(input) * scalar);

// Check the result of the multiplication has a tag that's the union of inputs' tags
EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag);
fq c_x_result(c.x.get_value().lo);
fq c_y_result(c.y.get_value().lo);

EXPECT_EQ(c_x_result, c_expected.x);

EXPECT_EQ(c_y_result, c_expected.y);
}

EXPECT_CIRCUIT_CORRECTNESS(builder);
}

static void test_short_scalar_mul_infinity()
{
Builder builder;
element input = element::infinity();

fr scalar(fr(6));
if (uint256_t(scalar).get_bit(0)) {
[Contributor]: What happens without the skew?
[Contributor Author]: I think it's an artifact of some edge-case testing; removed it.

scalar -= fr(1); // make sure to add skew
}
element_ct P = element_ct::from_witness(&builder, input);
scalar_ct x = scalar_ct::from_witness(&builder, scalar);

// Set input tags
x.set_origin_tag(challenge_origin_tag);
P.set_origin_tag(submitted_value_origin_tag);

std::cerr << "gates before mul " << builder.get_estimated_num_finalized_gates() << std::endl;
element_ct c = P.template scalar_mul<128>(x);
[Contributor]: You can check the number of gates.
[Contributor]: And since you've created the arbitrary-size scalar mul, could you create a test for every number of bits?
[Contributor Author]: Added the check. Added tests for even numbers of bits plus assert statements in compute_naf and bn254_endo_batch_mul. I don't see why there was an assertion on num bits >= 128. I'll add more tests for bn254_endo_batch_mul in a follow-up; maybe something fails when there are more small points.

std::cerr << "gates after mul " << builder.get_estimated_num_finalized_gates() << std::endl;

// Check the result of the multiplication has a tag that's the union of inputs' tags
EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag);

EXPECT_EQ(c.is_point_at_infinity().get_value(), true);

EXPECT_CIRCUIT_CORRECTNESS(builder);
}

static void test_twin_mul()
{
Builder builder;
@@ -950,25 +1018,36 @@ template <typename TestType> class stdlib_biggroup : public testing::Test {
static void test_compute_naf()
{
Builder builder = Builder();
size_t num_repetitions(32);
for (size_t i = 0; i < num_repetitions; i++) {
fr scalar_val = fr::random_element();
scalar_ct scalar = scalar_ct::from_witness(&builder, scalar_val);
// Set tag for scalar
scalar.set_origin_tag(submitted_value_origin_tag);
auto naf = element_ct::compute_naf(scalar);

for (const auto& bit : naf) {
// Check that the tag is propagated to bits
EXPECT_EQ(bit.get_origin_tag(), submitted_value_origin_tag);
std::vector<size_t> bit_lengths = { 254, 128 };
[Contributor Author — @iakovenkos, Feb 1, 2025]: Do we need any other lengths? Maybe something like 136?
[Contributor]: I don't think 136 makes sense here.
[Contributor Author]: I expanded the test range anyway.

for (auto max_num_bits : bit_lengths) {
size_t num_repetitions(32);
for (size_t i = 0; i < num_repetitions; i++) {
fr scalar_val = fr::random_element();
if (max_num_bits == 128) {
uint256_t scalar_raw = fr::random_element();
[Contributor]: You can get a random uint128 from the random engine.
[Contributor Author]: Switched to taking uint256_t modulo 256-length.

scalar_raw.data[2] = 0ULL;
scalar_raw.data[3] = 0ULL;
scalar_val = fr(scalar_raw);
}
scalar_ct scalar = scalar_ct::from_witness(&builder, scalar_val);
// Set tag for scalar
scalar.set_origin_tag(submitted_value_origin_tag);
auto naf = element_ct::compute_naf(scalar, max_num_bits);

for (const auto& bit : naf) {
// Check that the tag is propagated to bits
EXPECT_EQ(bit.get_origin_tag(), submitted_value_origin_tag);
}
// scalar = -naf[max_num_bits] + \sum_{i=0}^{max_num_bits-1}(1-2*naf[i]) 2^{max_num_bits-1-i}
fr reconstructed_val(0);
for (size_t i = 0; i < max_num_bits; i++) {
reconstructed_val +=
(fr(1) - fr(2) * fr(naf[i].witness_bool)) * fr(uint256_t(1) << (max_num_bits - 1 - i));
};
reconstructed_val -= fr(naf[max_num_bits].witness_bool);
EXPECT_EQ(scalar_val, reconstructed_val);
}
// scalar = -naf[254] + \sum_{i=0}^{253}(1-2*naf[i]) 2^{253-i}
fr reconstructed_val(0);
for (size_t i = 0; i < 254; i++) {
reconstructed_val += (fr(1) - fr(2) * fr(naf[i].witness_bool)) * fr(uint256_t(1) << (253 - i));
};
reconstructed_val -= fr(naf[254].witness_bool);
EXPECT_EQ(scalar_val, reconstructed_val);
}
EXPECT_CIRCUIT_CORRECTNESS(builder);
}
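For intuition, a tiny worked instance of the reconstruction identity checked in the test above, using a hypothetical toy size max_num_bits = 4 and the same digit convention (naf[i] = 0 encodes +1, naf[i] = 1 encodes -1, and the final entry is the skew bit subtracted at the end):

    6 = (+1)*2^3 + (-1)*2^2 + (+1)*2^1 + (+1)*2^0 - 1,   i.e. naf = {0, 1, 0, 0} with skew naf[4] = 1

The skew is needed because a signed sum that always includes a ±2^0 term has odd parity, so even scalars cannot be represented without it.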
@@ -1614,6 +1693,25 @@ HEAVY_TYPED_TEST(stdlib_biggroup, mul)
{
TestFixture::test_mul();
}

HEAVY_TYPED_TEST(stdlib_biggroup, short_scalar_mul)
{
if constexpr (HasGoblinBuilder<TypeParam>) {
GTEST_SKIP();
} else {
TestFixture::test_short_scalar_mul();
}
}

HEAVY_TYPED_TEST(stdlib_biggroup, short_scalar_mul_infinity)
{
if constexpr (HasGoblinBuilder<TypeParam>) {
GTEST_SKIP();
} else {
TestFixture::test_short_scalar_mul_infinity();
}
}

HEAVY_TYPED_TEST(stdlib_biggroup, twin_mul)
{
if constexpr (HasGoblinBuilder<TypeParam>) {
@@ -836,14 +836,24 @@ element<C, Fq, Fr, G> element<C, Fq, Fr, G>::batch_mul(const std::vector<element
return accumulator;
}
}
/**
* Implements scalar multiplication operator.
*/
template <typename C, class Fq, class Fr, class G>
element<C, Fq, Fr, G> element<C, Fq, Fr, G>::operator*(const Fr& scalar) const
{
// Use `scalar_mul` method without specifying the length of `scalar`.
return scalar_mul<0>(scalar);
}

/**
* Implements scalar multiplication.
* Implements scalar multiplication that supports short scalars.
*
* To perform multiple scalar multiplications, use one of the `batch_mul` methods to save gates.
**/
template <typename C, class Fq, class Fr, class G>
element<C, Fq, Fr, G> element<C, Fq, Fr, G>::operator*(const Fr& scalar) const
template <size_t max_num_bits>
element<C, Fq, Fr, G> element<C, Fq, Fr, G>::scalar_mul(const Fr& scalar) const
{
/**
*
@@ -868,27 +878,31 @@ element<C, Fq, Fr, G> element<C, Fq, Fr, G>::operator*(const Fr& scalar) const
* specifics.
*
**/
OriginTag tag{};
tag = OriginTag(tag, OriginTag(this->get_origin_tag(), scalar.get_origin_tag()));

constexpr uint64_t num_rounds = Fr::modulus.get_msb() + 1;

std::vector<bool_ct> naf_entries = compute_naf(scalar);
bool_ct is_point_at_infinity = this->is_point_at_infinity();

const auto offset_generators = compute_offset_generators(num_rounds);
const size_t num_rounds = (max_num_bits == 0) ? Fr::modulus.get_msb() + 1 : max_num_bits;

element accumulator = *this + offset_generators.first;
element result;
if constexpr (max_num_bits != 0) {
// The case of short scalars
result = element::bn254_endo_batch_mul({}, {}, { *this }, { scalar }, num_rounds);
} else {
// The case of arbitrary length scalars
result = element::bn254_endo_batch_mul({ *this }, { scalar }, {}, {}, num_rounds);
};

for (size_t i = 1; i < num_rounds; ++i) {
bool_ct predicate = naf_entries[i];
bigfield y_test = y.conditional_negate(predicate);
element to_add(x, y_test);
accumulator = accumulator.montgomery_ladder(to_add);
}
// Handle point at infinity
result.x = Fq::conditional_assign(is_point_at_infinity, x, result.x);
result.y = Fq::conditional_assign(is_point_at_infinity, y, result.y);

element skew_output = accumulator - (*this);
result.set_point_at_infinity(is_point_at_infinity);

Fq out_x = accumulator.x.conditional_select(skew_output.x, naf_entries[num_rounds]);
Fq out_y = accumulator.y.conditional_select(skew_output.y, naf_entries[num_rounds]);
// Propagate the origin tag
result.set_origin_tag(tag);

return element(out_x, out_y) - element(offset_generators.second);
return result;
}
} // namespace bb::stdlib::element_default
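A minimal usage sketch of the two entry points (reusing the element_ct / scalar_ct aliases from the tests earlier in this PR; the point and scalar variables are illustrative placeholders):

    element_ct P = element_ct::from_witness(&builder, some_affine_point);
    scalar_ct r = scalar_ct::from_witness(&builder, some_128_bit_scalar);

    element_ct a = P.template scalar_mul<128>(r); // short-scalar path: 128 NAF rounds via bn254_endo_batch_mul
    element_ct b = P * r;                         // full-width path, equivalent to P.scalar_mul<0>(r)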
@@ -441,7 +441,7 @@ std::vector<field_t<C>> element<C, Fq, Fr, G>::compute_wnaf(const Fr& scalar)
// updates multiplicative constants without computing new witnesses. This ensures the low accumulator will not
// underflow
//
// Once we hvae reconstructed an Fr element out of our accumulators,
// Once we have reconstructed an Fr element out of our accumulators,
// we ALSO construct an Fr element from the constant offset terms we left out
// We then subtract off the constant term and call `Fr::assert_is_in_field` to reduce the value modulo
// Fr::modulus
@@ -576,9 +576,23 @@ std::vector<bool_t<C>> element<C, Fq, Fr, G>::compute_naf(const Fr& scalar, cons
}
return std::make_pair(positive_accumulator, negative_accumulator);
};
const size_t midpoint = num_rounds - Fr::NUM_LIMB_BITS * 2;
[Contributor Author]: Although it accepts the length as an arg, it wasn't used to determine the midpoint.
auto hi_accumulators = reconstruct_half_naf(&naf_entries[0], midpoint);
auto lo_accumulators = reconstruct_half_naf(&naf_entries[midpoint], num_rounds - midpoint);
const size_t midpoint =
(num_rounds > Fr::NUM_LIMB_BITS * 2) ? num_rounds - Fr::NUM_LIMB_BITS * 2 : num_rounds / 2;

std::pair<field_t<C>, field_t<C>> hi_accumulators;
std::pair<field_t<C>, field_t<C>> lo_accumulators;

if (num_rounds > Fr::NUM_LIMB_BITS * 2) {
hi_accumulators = reconstruct_half_naf(&naf_entries[0], midpoint);
lo_accumulators = reconstruct_half_naf(&naf_entries[midpoint], num_rounds - midpoint);

} else {
// If the number of rounds is at most 2 * Fr::NUM_LIMB_BITS, the high half of the reconstructed Fr element is 0.
const field_t<C> zero = field_t<C>::from_witness_index(ctx, 0);
lo_accumulators = reconstruct_half_naf(&naf_entries[0], num_rounds);
hi_accumulators = std::make_pair(zero, zero);
}
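// For concreteness (assuming the common bigfield parameterization Fr::NUM_LIMB_BITS == 68, so two limbs
// cover 136 bits -- that constant is an assumption, it is not shown in this diff):
//   num_rounds = 254: midpoint = 254 - 136 = 118 -> hi covers naf[0..117], lo covers naf[118..253]
//   num_rounds = 128: 128 <= 136                 -> lo covers naf[0..127], hi is pinned to zero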

lo_accumulators.second = lo_accumulators.second + field_t<C>(naf_entries[num_rounds]);
