-
Notifications
You must be signed in to change notification settings - Fork 311
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
chore: aggregate with short scalars in UH Recursion #11478
Changes from 13 commits
22b4e85
aef633b
57f8665
6f2490d
35436d8
34d095f
793846a
026c942
95ee6e2
e27b335
b6c3bb3
19ae624
cf4942e
4350a97
d9dc496
3cda849
61d9ac2
d2d2999
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,17 +20,36 @@ template <typename Curve> struct aggregation_state { | |
{ | ||
return P0 == other.P0 && P1 == other.P1; | ||
}; | ||
|
||
template <typename BuilderType = void> | ||
void aggregate(aggregation_state const& other, typename Curve::ScalarField recursion_separator) | ||
{ | ||
P0 += other.P0 * recursion_separator; | ||
P1 += other.P1 * recursion_separator; | ||
if constexpr (std::is_same_v<BuilderType, MegaCircuitBuilder>) { | ||
P0 += other.P0 * recursion_separator; | ||
P1 += other.P1 * recursion_separator; | ||
} else { | ||
// Save gates using short scalars. We don't apply `bn254_endo_batch_mul` to the vector {1, | ||
// recursion_separator} directly to avoid edge cases. | ||
typename Curve::Group point_to_aggregate = other.P0.template scalar_mul<128>(recursion_separator); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this costs ~ 72K per aggregate call |
||
P0 += point_to_aggregate; | ||
point_to_aggregate = other.P1.template scalar_mul<128>(recursion_separator); | ||
P1 += point_to_aggregate; | ||
} | ||
} | ||
|
||
template <typename BuilderType = void> | ||
void aggregate(std::array<typename Curve::Group, 2> const& other, typename Curve::ScalarField recursion_separator) | ||
{ | ||
P0 += other[0] * recursion_separator; | ||
P1 += other[1] * recursion_separator; | ||
if constexpr (std::is_same_v<BuilderType, MegaCircuitBuilder>) { | ||
P0 += other[0] * recursion_separator; | ||
P1 += other[1] * recursion_separator; | ||
} else { | ||
// Save gates using short scalars. We don't apply `bn254_endo_batch_mul` to the vector {1, | ||
// recursion_separator} directly to avoid edge cases. | ||
typename Curve::Group point_to_aggregate = other[0].template scalar_mul<128>(recursion_separator); | ||
P0 += point_to_aggregate; | ||
point_to_aggregate = other[1].template scalar_mul<128>(recursion_separator); | ||
P1 += point_to_aggregate; | ||
} | ||
} | ||
|
||
PairingPointAccumulatorIndices get_witness_indices() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -143,6 +143,7 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element { | |
result.y.assert_is_in_field(); | ||
return result; | ||
} | ||
template <size_t max_num_bits> element scalar_mul(const Fr& scalar) const; | ||
|
||
element reduce() const | ||
{ | ||
|
@@ -525,7 +526,10 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element { | |
num_fives = num_points / 5; | ||
num_sixes = 0; | ||
// size-6 table is expensive and only benefits us if creating them reduces the number of total tables | ||
if (num_fives * 5 == (num_points - 1)) { | ||
if (num_points == 1) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle the case of a single point to re-use this function in the * operator and in Probably there are other edge cases |
||
num_fives = 0; | ||
num_sixes = 0; | ||
} else if (num_fives * 5 == (num_points - 1)) { | ||
num_fives -= 1; | ||
num_sixes = 1; | ||
} else if (num_fives * 5 == (num_points - 2) && num_fives >= 2) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -414,6 +414,74 @@ template <typename TestType> class stdlib_biggroup : public testing::Test { | |
EXPECT_CIRCUIT_CORRECTNESS(builder); | ||
} | ||
|
||
static void test_short_scalar_mul() | ||
{ | ||
Builder builder; | ||
size_t num_repetitions = 1; | ||
for (size_t i = 0; i < num_repetitions; ++i) { | ||
affine_element input(element::random_element()); | ||
// Get 128-bit scalar | ||
uint256_t scalar_raw = fr::random_element(); | ||
scalar_raw.data[2] = 0ULL; | ||
scalar_raw.data[3] = 0ULL; | ||
fr scalar = fr(scalar_raw); | ||
// Add skew | ||
if (uint256_t(scalar).get_bit(0)) { | ||
scalar -= fr(1); | ||
} | ||
element_ct P = element_ct::from_witness(&builder, input); | ||
scalar_ct x = scalar_ct::from_witness(&builder, scalar); | ||
|
||
// Set input tags | ||
x.set_origin_tag(challenge_origin_tag); | ||
P.set_origin_tag(submitted_value_origin_tag); | ||
|
||
std::cerr << "gates before mul " << builder.get_estimated_num_finalized_gates() << std::endl; | ||
element_ct c = P.template scalar_mul<128>(x); | ||
std::cerr << "builder aftr mul " << builder.get_estimated_num_finalized_gates() << std::endl; | ||
affine_element c_expected(element(input) * scalar); | ||
|
||
// Check the result of the multiplication has a tag that's the union of inputs' tags | ||
EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag); | ||
fq c_x_result(c.x.get_value().lo); | ||
fq c_y_result(c.y.get_value().lo); | ||
|
||
EXPECT_EQ(c_x_result, c_expected.x); | ||
|
||
EXPECT_EQ(c_y_result, c_expected.y); | ||
} | ||
|
||
EXPECT_CIRCUIT_CORRECTNESS(builder); | ||
} | ||
|
||
static void test_short_scalar_mul_infinity() | ||
{ | ||
Builder builder; | ||
element input = element::infinity(); | ||
|
||
fr scalar(fr(6)); | ||
if (uint256_t(scalar).get_bit(0)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens without the skew? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's an artifact of some edge case testing, removed it. |
||
scalar -= fr(1); // make sure to add skew | ||
} | ||
element_ct P = element_ct::from_witness(&builder, input); | ||
scalar_ct x = scalar_ct::from_witness(&builder, scalar); | ||
|
||
// Set input tags | ||
x.set_origin_tag(challenge_origin_tag); | ||
P.set_origin_tag(submitted_value_origin_tag); | ||
|
||
std::cerr << "gates before mul " << builder.get_estimated_num_finalized_gates() << std::endl; | ||
element_ct c = P.template scalar_mul<128>(x); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can check the number of gates There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And since you've created the arbitrary size scalar mul, could you create a test for every number of bits? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added the check |
||
std::cerr << "builder aftr mul " << builder.get_estimated_num_finalized_gates() << std::endl; | ||
|
||
// Check the result of the multiplication has a tag that's the union of inputs' tags | ||
EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag); | ||
|
||
EXPECT_EQ(c.is_point_at_infinity().get_value(), true); | ||
|
||
EXPECT_CIRCUIT_CORRECTNESS(builder); | ||
} | ||
|
||
static void test_twin_mul() | ||
{ | ||
Builder builder; | ||
|
@@ -950,25 +1018,36 @@ template <typename TestType> class stdlib_biggroup : public testing::Test { | |
static void test_compute_naf() | ||
{ | ||
Builder builder = Builder(); | ||
size_t num_repetitions(32); | ||
for (size_t i = 0; i < num_repetitions; i++) { | ||
fr scalar_val = fr::random_element(); | ||
scalar_ct scalar = scalar_ct::from_witness(&builder, scalar_val); | ||
// Set tag for scalar | ||
scalar.set_origin_tag(submitted_value_origin_tag); | ||
auto naf = element_ct::compute_naf(scalar); | ||
|
||
for (const auto& bit : naf) { | ||
// Check that the tag is propagated to bits | ||
EXPECT_EQ(bit.get_origin_tag(), submitted_value_origin_tag); | ||
std::vector<size_t> bit_lengths = { 254, 128 }; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need any other lengths? maybe smth like 136? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think 136 makes sense here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I expanded the test range anyway |
||
|
||
for (auto max_num_bits : bit_lengths) { | ||
size_t num_repetitions(32); | ||
for (size_t i = 0; i < num_repetitions; i++) { | ||
fr scalar_val = fr::random_element(); | ||
if (max_num_bits == 128) { | ||
uint256_t scalar_raw = fr::random_element(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can get a random uint128 from the random engine There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. switched to taking uint256_t modulo 256-length |
||
scalar_raw.data[2] = 0ULL; | ||
scalar_raw.data[3] = 0ULL; | ||
scalar_val = fr(scalar_raw); | ||
} | ||
scalar_ct scalar = scalar_ct::from_witness(&builder, scalar_val); | ||
// Set tag for scalar | ||
scalar.set_origin_tag(submitted_value_origin_tag); | ||
auto naf = element_ct::compute_naf(scalar, max_num_bits); | ||
|
||
for (const auto& bit : naf) { | ||
// Check that the tag is propagated to bits | ||
EXPECT_EQ(bit.get_origin_tag(), submitted_value_origin_tag); | ||
} | ||
// scalar = -naf[254] + \sum_{i=0}^{253}(1-2*naf[i]) 2^{253-i} | ||
fr reconstructed_val(0); | ||
for (size_t i = 0; i < max_num_bits; i++) { | ||
reconstructed_val += | ||
(fr(1) - fr(2) * fr(naf[i].witness_bool)) * fr(uint256_t(1) << (max_num_bits - 1 - i)); | ||
}; | ||
reconstructed_val -= fr(naf[max_num_bits].witness_bool); | ||
EXPECT_EQ(scalar_val, reconstructed_val); | ||
} | ||
// scalar = -naf[254] + \sum_{i=0}^{253}(1-2*naf[i]) 2^{253-i} | ||
fr reconstructed_val(0); | ||
for (size_t i = 0; i < 254; i++) { | ||
reconstructed_val += (fr(1) - fr(2) * fr(naf[i].witness_bool)) * fr(uint256_t(1) << (253 - i)); | ||
}; | ||
reconstructed_val -= fr(naf[254].witness_bool); | ||
EXPECT_EQ(scalar_val, reconstructed_val); | ||
} | ||
EXPECT_CIRCUIT_CORRECTNESS(builder); | ||
} | ||
|
@@ -1614,6 +1693,25 @@ HEAVY_TYPED_TEST(stdlib_biggroup, mul) | |
{ | ||
TestFixture::test_mul(); | ||
} | ||
|
||
HEAVY_TYPED_TEST(stdlib_biggroup, short_scalar_mul) | ||
{ | ||
if constexpr (HasGoblinBuilder<TypeParam>) { | ||
GTEST_SKIP(); | ||
} else { | ||
TestFixture::test_short_scalar_mul(); | ||
} | ||
} | ||
|
||
HEAVY_TYPED_TEST(stdlib_biggroup, short_scalar_mul_infinity) | ||
{ | ||
if constexpr (HasGoblinBuilder<TypeParam>) { | ||
GTEST_SKIP(); | ||
} else { | ||
TestFixture::test_short_scalar_mul_infinity(); | ||
} | ||
} | ||
|
||
HEAVY_TYPED_TEST(stdlib_biggroup, twin_mul) | ||
{ | ||
if constexpr (HasGoblinBuilder<TypeParam>) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -441,7 +441,7 @@ std::vector<field_t<C>> element<C, Fq, Fr, G>::compute_wnaf(const Fr& scalar) | |
// updates multiplicative constants without computing new witnesses. This ensures the low accumulator will not | ||
// underflow | ||
// | ||
// Once we hvae reconstructed an Fr element out of our accumulators, | ||
// Once we have reconstructed an Fr element out of our accumulators, | ||
// we ALSO construct an Fr element from the constant offset terms we left out | ||
// We then subtract off the constant term and call `Fr::assert_is_in_field` to reduce the value modulo | ||
// Fr::modulus | ||
|
@@ -576,9 +576,23 @@ std::vector<bool_t<C>> element<C, Fq, Fr, G>::compute_naf(const Fr& scalar, cons | |
} | ||
return std::make_pair(positive_accumulator, negative_accumulator); | ||
}; | ||
const size_t midpoint = num_rounds - Fr::NUM_LIMB_BITS * 2; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. although it accepts the length as an arg, it wasn't used to determine the midpoint |
||
auto hi_accumulators = reconstruct_half_naf(&naf_entries[0], midpoint); | ||
auto lo_accumulators = reconstruct_half_naf(&naf_entries[midpoint], num_rounds - midpoint); | ||
const size_t midpoint = | ||
(num_rounds > Fr::NUM_LIMB_BITS * 2) ? num_rounds - Fr::NUM_LIMB_BITS * 2 : num_rounds / 2; | ||
|
||
std::pair<field_t<C>, field_t<C>> hi_accumulators; | ||
std::pair<field_t<C>, field_t<C>> lo_accumulators; | ||
|
||
if (num_rounds > Fr::NUM_LIMB_BITS * 2) { | ||
hi_accumulators = reconstruct_half_naf(&naf_entries[0], midpoint); | ||
lo_accumulators = reconstruct_half_naf(&naf_entries[midpoint], num_rounds - midpoint); | ||
|
||
} else { | ||
// If the number of rounds is smaller than Fr::NUM_LIMB_BITS, the high bits of the resulting Fr element are | ||
// 0. | ||
const field_t<C> zero = field_t<C>::from_witness_index(ctx, 0); | ||
lo_accumulators = reconstruct_half_naf(&naf_entries[0], num_rounds); | ||
hi_accumulators = std::make_pair(zero, zero); | ||
} | ||
|
||
lo_accumulators.second = lo_accumulators.second + field_t<C>(naf_entries[num_rounds]); | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this costs >140K per aggregate call