From 15808f6bdd808781e4df599e31d900c8b4a443cc Mon Sep 17 00:00:00 2001 From: Michael Connor Date: Tue, 27 Sep 2022 12:21:26 +0100 Subject: [PATCH] Revert "Bug fixes for issues found during fuzzing bigfield and safeuint as well as over cases (#1279)" (#1527) This reverts commit 88c52fd6e3833fddeb8cc7adfdb70537a57e4e8e. --- src/aztec/ecc/curves/bn254/fq.test.cpp | 26 - .../ecc/curves/secp256k1/secp256k1.test.cpp | 12 - src/aztec/ecc/fields/asm_macros.hpp | 20 +- src/aztec/ecc/fields/field.hpp | 1 - src/aztec/ecc/fields/field_impl.hpp | 10 +- src/aztec/ecc/groups/affine_element.test.cpp | 13 +- src/aztec/ecc/groups/affine_element_impl.hpp | 7 - src/aztec/numeric/uint256/uint256.hpp | 3 +- .../stdlib/primitives/bigfield/bigfield.hpp | 6 - .../primitives/bigfield/bigfield.test.cpp | 36 +- .../primitives/bigfield/bigfield_impl.hpp | 452 ++++++------------ src/aztec/stdlib/primitives/field/field.cpp | 13 +- .../stdlib/primitives/safe_uint/safe_uint.cpp | 21 +- .../stdlib/primitives/safe_uint/safe_uint.hpp | 7 +- .../primitives/safe_uint/safe_uint.test.cpp | 12 +- 15 files changed, 190 insertions(+), 449 deletions(-) diff --git a/src/aztec/ecc/curves/bn254/fq.test.cpp b/src/aztec/ecc/curves/bn254/fq.test.cpp index ead55d8d1f2..3e4c0555ca1 100644 --- a/src/aztec/ecc/curves/bn254/fq.test.cpp +++ b/src/aztec/ecc/curves/bn254/fq.test.cpp @@ -471,29 +471,3 @@ TEST(fq, pow_regression_check) EXPECT_EQ(zero.pow(uint256_t(0)), one); } // 438268ca91d42ad f1e7025a7b654e1f f8d9d72e0438b995 8c422ec208ac8a6e - -TEST(fq, sqr_regression) -{ - uint256_t values[] = { uint256_t(0xbdf876654b0ade1b, 0x2c3a66c64569f338, 0x2cd8bf2ec1fe55a3, 0x11c0ea9ee5693ede), - uint256_t(0x551b14ec34f2151c, 0x62e472ed83a2891e, 0xf208d5e5c9b5b3fb, 0x14315aeaf6027d8c), - uint256_t(0xad39959ae8013750, 0x7f1d2c709ab84cbb, 0x408028b80a60c2f1, 0x1dcd116fc26f856e), - uint256_t(0x95e967d30dcce9ce, 0x56139274241d2ea1, 0x85b19c1c616ec456, 0x1f1780cf9bf045b4), - uint256_t(0xbe841c861d8eb80e, 0xc5980d67a21386c0, 0x5fd1f1afecddeeb5, 0x24dbb8c1baea0250), - uint256_t(0x3ae4b3a27f05d6e3, 0xc5f6785b12df8d29, 0xc3a6c5f095103046, 0xd6b94cb2cc1fd4b), - uint256_t(0xc003c71932a6ced5, 0x6302a413f68e26e9, 0x2ed4a9b64d69fad, 0xfe61ffab1ae227d) }; - for (auto& value : values) { - fq element(value); - EXPECT_EQ(element.sqr(), element * element); - } -} - -TEST(fq, neg_and_self_neg_0_cmp_regression) -{ - fq a = 0; - fq a_neg = -a; - EXPECT_EQ((a == a_neg), true); - a = 0; - a_neg = 0; - a_neg.self_neg(); - EXPECT_EQ((a == a_neg), true); -} \ No newline at end of file diff --git a/src/aztec/ecc/curves/secp256k1/secp256k1.test.cpp b/src/aztec/ecc/curves/secp256k1/secp256k1.test.cpp index c05b356f710..0491eb737e3 100644 --- a/src/aztec/ecc/curves/secp256k1/secp256k1.test.cpp +++ b/src/aztec/ecc/curves/secp256k1/secp256k1.test.cpp @@ -494,16 +494,4 @@ TEST(secp256k1, derive_generators) } } */ - -TEST(secp256k1, neg_and_self_neg_0_cmp_regression) -{ - secp256k1::fq a = 0; - secp256k1::fq a_neg = -a; - EXPECT_EQ((a == a_neg), true); - a = 0; - a_neg = 0; - a_neg.self_neg(); - EXPECT_EQ((a == a_neg), true); -} - } // namespace test_secp256k1 \ No newline at end of file diff --git a/src/aztec/ecc/fields/asm_macros.hpp b/src/aztec/ecc/fields/asm_macros.hpp index bb657b0f030..cfe17a46356 100644 --- a/src/aztec/ecc/fields/asm_macros.hpp +++ b/src/aztec/ecc/fields/asm_macros.hpp @@ -162,7 +162,7 @@ "mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \ "adcq %%rcx, %%r10 \n\t" /* r[2] += t[3] + flag_c */ \ "adcq $0, %%r11 \n\t" /* r[4] += flag_c */ \ - /* Partial fix "adcq $0, %%r12 \n\t"*/ /* r[4] += flag_c */ \ +/* Partial fix "adcq $0, %%r12 \n\t"*/ /* r[4] += flag_c */ \ "addq %%rdi, %%r9 \n\t" /* r[1] += t[2] */ \ "mulxq %[modulus_2], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[3] * k) */ \ "mulxq %[modulus_3], %%r8, %%rdx \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \ @@ -540,7 +540,6 @@ "adcxq %%rcx, %%r13 \n\t" /* r[5] += t[4] + flag_o */ \ "adoxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \ "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \ - "adoxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \ \ /* double result registers */ \ "adoxq %%r9, %%r9 \n\t" /* r[1] = 2r[1] */ \ @@ -574,12 +573,10 @@ "mulxq %[modulus_0], %%rdi, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \ "adoxq %%rdi, %%r8 \n\t" /* r[0] += t[0] (%r8 now free) */ \ "mulxq %[modulus_3], %%r8, %%rdi \n\t" /* (t[2], t[3]) <- (modulus[2] * k) */ \ - "adcxq %%rdi, %%r12 \n\t" /* r[4] += t[3] + flag_c */ \ - "adoxq %%rcx, %%r9 \n\t" /* r[1] += t[1] + flag_o */ \ - "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \ - "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \ + "adcxq %%rdi, %%r12 \n\t" /* r[4] += t[3] + flag_o */ \ + "adoxq %%rcx, %%r9 \n\t" /* r[1] += t[1] + flag_c */ \ + "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \ "mulxq %[modulus_1], %%rdi, %%rcx \n\t" /* (t[2], t[3]) <- (modulus[1] * k) */ \ - "adcxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \ "adoxq %%rcx, %%r10 \n\t" /* r[2] += t[3] + flag_o */ \ "adcxq %%rdi, %%r9 \n\t" /* r[1] += t[2] */ \ "adoxq %%r8, %%r11 \n\t" /* r[3] += t[2] + flag_o */ \ @@ -597,9 +594,6 @@ "adoxq %%rcx, %%r13 \n\t" /* r[5] += t[3] + flag_o */ \ "adcxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_c */ \ "adoxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \ - "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_c */ \ - "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ \ - "adcxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \ "mulxq %[modulus_0], %%r8, %%rcx \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \ "adcxq %%r8, %%r9 \n\t" /* r[1] += t[0] (%r9 now free) */ \ "adoxq %%rcx, %%r10 \n\t" /* r[2] += t[1] + flag_c */ \ @@ -620,14 +614,12 @@ "adcxq %%r8, %%r13 \n\t" /* r[5] += t[2] + flag_c */ \ "adoxq %%r9, %%r14 \n\t" /* r[6] += t[3] + flag_c */ \ "adcxq %[zero_reference], %%r14 \n\t" /* r[6] += flag_o */ \ - "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_o */ \ - "adcxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \ + "adoxq %[zero_reference], %%r15 \n\t" /* r[7] += flag_c */ \ "mulxq %[modulus_0], %%r8, %%r9 \n\t" /* (t[0], t[1]) <- (modulus[0] * k) */ \ "adcxq %%r8, %%r10 \n\t" /* r[2] += t[0] (%r10 now free) */ \ "adoxq %%r9, %%r11 \n\t" /* r[3] += t[1] + flag_c */ \ "adcxq %%rdi, %%r11 \n\t" /* r[3] += t[2] */ \ - "adoxq %[zero_reference], %%r12 \n\t" /* r[4] += flag_o */ \ - "adoxq %[zero_reference], %%r13 \n\t" /* r[5] += flag_o */ \ + "adoxq %[zero_reference], %%r12 \n\t" /* r[4] += flag_c */ \ \ /* perform modular reduction: r[3] */ \ "movq %%r11, %%rdx \n\t" /* move r11 into %rdx */ \ diff --git a/src/aztec/ecc/fields/field.hpp b/src/aztec/ecc/fields/field.hpp index 973ace6af0a..09101a87259 100644 --- a/src/aztec/ecc/fields/field.hpp +++ b/src/aztec/ecc/fields/field.hpp @@ -23,7 +23,6 @@ namespace barretenberg { template struct alignas(32) field { public: - // We don't initialize data by default since we'd lose a lot of time on pointless initializations. field() noexcept {} constexpr field(const uint256_t& input) noexcept diff --git a/src/aztec/ecc/fields/field_impl.hpp b/src/aztec/ecc/fields/field_impl.hpp index 7331206ebca..71dd9cba122 100644 --- a/src/aztec/ecc/fields/field_impl.hpp +++ b/src/aztec/ecc/fields/field_impl.hpp @@ -153,7 +153,7 @@ template constexpr field field::operator-() const noexcept return p - *this; // modulus - *this; } constexpr field p{ twice_modulus.data[0], twice_modulus.data[1], twice_modulus.data[2], twice_modulus.data[3] }; - return (p - *this).reduce_once(); // modulus - *this; + return p - *this; // modulus - *this; } template constexpr field field::operator-=(const field& other) noexcept @@ -179,7 +179,7 @@ template constexpr void field::self_neg() noexcept *this = p - *this; } else { constexpr field p{ twice_modulus.data[0], twice_modulus.data[1], twice_modulus.data[2], twice_modulus.data[3] }; - *this = (p - *this).reduce_once(); + *this = p - *this; } } @@ -236,9 +236,9 @@ template constexpr field field::to_montgomery_form() const noexc constexpr field r_squared{ T::r_squared_0, T::r_squared_1, T::r_squared_2, T::r_squared_3 }; field result = *this; - result.self_reduce_once(); - result.self_reduce_once(); - result.self_reduce_once(); + result.reduce_once(); + result.reduce_once(); + result.reduce_once(); return (result * r_squared).reduce_once(); } diff --git a/src/aztec/ecc/groups/affine_element.test.cpp b/src/aztec/ecc/groups/affine_element.test.cpp index a7ad49abc13..eecde30ac63 100644 --- a/src/aztec/ecc/groups/affine_element.test.cpp +++ b/src/aztec/ecc/groups/affine_element.test.cpp @@ -1,5 +1,4 @@ #include -#include #include #include #include @@ -29,21 +28,11 @@ TEST(affine_element, read_write_buffer) // Regression test to ensure that the point at infinity is not equal to its coordinate-wise reduction, which may lie // on the curve, depending on the y-coordinate. -TEST(affine_element, infinity_equality_regression) +TEST(affine_element, infinity_regression) { g1::affine_element P; P.self_set_infinity(); g1::affine_element R(0, P.y); ASSERT_FALSE(P == R); } - -// Regression test to ensure that the point at infinity is not equal to its coordinate-wise reduction, which may lie -// on the curve, depending on the y-coordinate. -TEST(affine_element, infinity_ordering_regression) -{ - secp256k1::g1::affine_element P(0, 1), Q(0, 1); - - P.self_set_infinity(); - EXPECT_NE(P < Q, Q < P); -} } // namespace test_affine_element \ No newline at end of file diff --git a/src/aztec/ecc/groups/affine_element_impl.hpp b/src/aztec/ecc/groups/affine_element_impl.hpp index 20afc962476..b6b6f622495 100644 --- a/src/aztec/ecc/groups/affine_element_impl.hpp +++ b/src/aztec/ecc/groups/affine_element_impl.hpp @@ -136,13 +136,6 @@ constexpr bool affine_element::operator==(const affine_element& other template constexpr bool affine_element::operator>(const affine_element& other) const noexcept { - // We are setting point at infinity to always be the lowest element - if (is_point_at_infinity()) { - return false; - } else if (other.is_point_at_infinity()) { - return true; - } - if (x > other.x) { return true; } else if (x == other.x && y > other.y) { diff --git a/src/aztec/numeric/uint256/uint256.hpp b/src/aztec/numeric/uint256/uint256.hpp index d446eac47b3..bafa6f626ec 100644 --- a/src/aztec/numeric/uint256/uint256.hpp +++ b/src/aztec/numeric/uint256/uint256.hpp @@ -147,8 +147,6 @@ class alignas(32) uint256_t { uint64_t data[4]; - constexpr std::pair divmod(const uint256_t& b) const; - private: constexpr std::pair mul_wide(const uint64_t a, const uint64_t b) const; constexpr std::pair addc(const uint64_t a, const uint64_t b, const uint64_t carry_in) const; @@ -164,6 +162,7 @@ class alignas(32) uint256_t { const uint64_t b, const uint64_t c, const uint64_t carry_in) const; + constexpr std::pair divmod(const uint256_t& b) const; }; inline std::ostream& operator<<(std::ostream& os, uint256_t const& a) diff --git a/src/aztec/stdlib/primitives/bigfield/bigfield.hpp b/src/aztec/stdlib/primitives/bigfield/bigfield.hpp index 4d16a684d3c..0fba46b0cc9 100644 --- a/src/aztec/stdlib/primitives/bigfield/bigfield.hpp +++ b/src/aztec/stdlib/primitives/bigfield/bigfield.hpp @@ -47,11 +47,6 @@ template class bigfield { maximum_value = DEFAULT_MAXIMUM_LIMB; } } - friend std::ostream& operator<<(std::ostream& os, const Limb& a) - { - os << "{ " << a.element << " < " << a.maximum_value << " }"; - return os; - } Limb(const Limb& other) = default; Limb(Limb&& other) = default; Limb& operator=(const Limb& other) = default; @@ -217,7 +212,6 @@ template class bigfield { const std::vector& to_sub, bool enable_divisor_nz_check = false); - static bigfield sum(const std::vector& terms); static bigfield internal_div(const std::vector& numerators, const bigfield& denominator, bool check_for_zero); diff --git a/src/aztec/stdlib/primitives/bigfield/bigfield.test.cpp b/src/aztec/stdlib/primitives/bigfield/bigfield.test.cpp index 094474ec554..4fca9cae57d 100644 --- a/src/aztec/stdlib/primitives/bigfield/bigfield.test.cpp +++ b/src/aztec/stdlib/primitives/bigfield/bigfield.test.cpp @@ -42,22 +42,6 @@ template class stdlib_bigfield : public testing::Test { typedef typename bn254::witness_ct witness_ct; public: - // The bug happens when we are applying the CRT formula to a*b < r, which can happen when using the division - // operator - static void test_fuzzer_bug() - { - auto composer = Composer(); - uint256_t value(2); - fq_ct tval = fq_ct::create_from_u512_as_witness(&composer, value); - fq_ct tval1 = tval - tval; - fq_ct tval2 = tval1 / tval; - (void)tval2; - auto prover = composer.create_prover(); - auto verifier = composer.create_verifier(); - waffle::plonk_proof proof = prover.construct_proof(); - bool proof_result = verifier.verify_proof(proof); - EXPECT_EQ(proof_result, true); - } static void test_bad_mul() { @@ -780,17 +764,6 @@ template class stdlib_bigfield : public testing::Test { bool proof_result = verifier.verify_proof(proof); EXPECT_EQ(proof_result, true); } - - static void test_conditional_select_regression() - { - auto composer = Composer(); - barretenberg::fq a(0); - barretenberg::fq b(1); - fq_ct a_ct(&composer, a); - fq_ct b_ct(&composer, b); - fq_ct selected = a_ct.conditional_select(b_ct, typename bn254::bool_ct(&composer, true)); - EXPECT_EQ(barretenberg::fq((selected.get_value() % uint512_t(barretenberg::fq::modulus)).lo), b); - } }; // Define types for which the above tests will be constructed. @@ -801,10 +774,7 @@ typedef testing::Types bigfield bigfield::operator-(const if (other.is_constant()) { uint512_t right = other.get_value() % modulus_u512; - uint512_t neg_right = (modulus_u512 - right) % modulus_u512; + uint512_t neg_right = modulus_u512 - right; return operator+(bigfield(ctx, uint256_t(neg_right.lo))); } @@ -472,30 +472,6 @@ template bigfield bigfield::operator/(const return internal_div({ *this }, other, false); } -/** - * @brief Create constraints for summing these terms - * - * @tparam C - * @tparam T - * @param terms - * @return The sum of terms - */ -template bigfield bigfield::sum(const std::vector& terms) -{ - ASSERT(terms.size() > 0); - - if (terms.size() == 1) { - return terms[0]; - } - std::vector halved; - for (size_t i = 0; i < terms.size() / 2; i++) { - halved.push_back(terms[2 * i] + terms[2 * i + 1]); - } - if (terms.size() & 1) { - halved.push_back(terms[terms.size() - 1]); - } - return sum(halved); -} /** * Division of a sum with an optional check if divisor is zero. Should not be used outside of class. @@ -514,18 +490,8 @@ bigfield bigfield::internal_div(const std::vector& numerat if (numerators.size() == 0) { return bigfield(nullptr, uint256_t(0)); } - // This is a temporary fix for completeness bug - // TODO: Try to implement a different base formula to get rid of the this summing behaviour, which may cost us - // gates. - if (numerators.size() > 1) { - std::vector single_numerator; - single_numerator.push_back(sum(numerators)); - return internal_div(single_numerator, denominator, check_for_zero); - } denominator.reduction_check(); - // This is a temporary fix for completeness bug - numerators[0].self_reduce(); C* ctx = denominator.context; uint512_t numerator_values(0); @@ -539,7 +505,6 @@ bigfield bigfield::internal_div(const std::vector& numerat // a / b = c // => c * b = a mod p - const uint1024_t left = uint1024_t(numerator_values); const uint1024_t right = uint1024_t(denominator.get_value()); const uint1024_t modulus(target_basis.modulus); @@ -683,25 +648,16 @@ template bigfield bigfield::sqradd(const st const uint1024_t add_right(add_values); const uint1024_t modulus(target_basis.modulus); - bigfield remainder; - bigfield quotient; - if (is_constant()) { - if (add_constant) { - - const auto [quotient_1024, remainder_1024] = (left * right + add_right).divmod(modulus); - remainder = bigfield(ctx, uint256_t(remainder_1024.lo.lo)); - return remainder; - } else { + const auto [quotient_1024, remainder_1024] = (left * right + add_right).divmod(modulus); - const auto [quotient_1024, remainder_1024] = (left * right).divmod(modulus); - std::vector new_to_add; - for (auto& add_element : to_add) { - new_to_add.push_back(add_element); - } + const uint512_t quotient_value = quotient_1024.lo; + const uint512_t remainder_value = remainder_1024.lo; - new_to_add.push_back(bigfield(ctx, remainder_1024.lo.lo)); - return sum(new_to_add); - } + bigfield remainder; + bigfield quotient; + if (is_constant() && add_constant) { + remainder = bigfield(ctx, uint256_t(remainder_value.lo)); + return remainder; } else { // Check the quotient fits the range proof @@ -712,17 +668,14 @@ template bigfield bigfield::sqradd(const st self_reduce(); return sqradd(to_add); } - const auto [quotient_1024, remainder_1024] = (left * right + add_right).divmod(modulus); - uint512_t quotient_value = quotient_1024.lo; - uint256_t remainder_value = remainder_1024.lo.lo; quotient = bigfield(witness_t(ctx, fr(quotient_value.slice(0, NUM_LIMB_BITS * 2).lo)), witness_t(ctx, fr(quotient_value.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 4).lo)), false, num_quotient_bits); remainder = bigfield( - witness_t(ctx, fr(remainder_value.slice(0, NUM_LIMB_BITS * 2))), - witness_t(ctx, fr(remainder_value.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 3 + NUM_LAST_LIMB_BITS)))); + witness_t(ctx, fr(remainder_value.slice(0, NUM_LIMB_BITS * 2).lo)), + witness_t(ctx, fr(remainder_value.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 3 + NUM_LAST_LIMB_BITS).lo))); }; unsafe_evaluate_square_add(*this, to_add, quotient, remainder); return remainder; @@ -771,6 +724,7 @@ bigfield bigfield::madd(const bigfield& to_mul, const std::vector to_mul.get_maximum_value()) { self_reduce(); @@ -947,129 +901,48 @@ bigfield bigfield::mult_madd(const std::vector& mul_left, const size_t number_of_products = mul_left.size(); + // First we need to check if it is possible to reduce the products enough + const uint1024_t modulus(target_basis.modulus); uint1024_t worst_case_product_sum(0); - uint1024_t add_right_constant_sum(0); + uint1024_t add_right(0); + uint1024_t add_right_maximum(0); - // First we do all constant optimizations + // Compute the sum of added values (we don't force-reduce these) + // We use add_right later for computing the quotient and remainder bool add_constant = true; - std::vector new_to_add; - for (const auto& add_element : to_add) { add_element.reduction_check(); - if (add_element.is_constant()) { - add_right_constant_sum += uint1024_t(add_element.get_value()); - } else { - add_constant = false; - new_to_add.push_back(add_element); - } - } - - // Compute the product sum - // Optimize constant use - uint1024_t sum_of_constant_products(0); - std::vector new_input_left; - std::vector new_input_right; - bool product_sum_constant = true; - for (size_t i = 0; i < number_of_products; i++) { - if (mutable_mul_left[i].is_constant() && mutable_mul_right[i].is_constant()) { - // If constant, just add to the sum - sum_of_constant_products += - uint1024_t(mutable_mul_left[i].get_value()) * uint1024_t(mutable_mul_right[i].get_value()); - } else { - // If not, add to nonconstant sum and remember the elements - new_input_left.push_back(mutable_mul_left[i]); - new_input_right.push_back(mutable_mul_right[i]); - product_sum_constant = false; - } - } - - C* ctx = nullptr; - // Search through all multiplicands on the left - for (auto& el : mutable_mul_left) { - if (el.context) { - ctx = el.context; - break; - } - } - // And on the right - if (!ctx) { - for (auto& el : mutable_mul_right) { - if (el.context) { - ctx = el.context; - break; - } - } - } - if (product_sum_constant) { - if (add_constant) { - // Simply return the constant, no need unsafe_multiply_add - const auto [quotient_1024, remainder_1024] = - (sum_of_constant_products + add_right_constant_sum).divmod(modulus); - ASSERT(!fix_remainder_to_zero || remainder_1024 == 0); - return bigfield(ctx, uint256_t(remainder_1024.lo.lo)); - } else { - const auto [quotient_1024, remainder_1024] = - (sum_of_constant_products + add_right_constant_sum).divmod(modulus); - uint256_t remainder_value = remainder_1024.lo.lo; - bigfield result; - if (remainder_value == uint256_t(0)) { - // No need to add extra term to new_to_add - result = sum(new_to_add); - } else { - // Add the constant term - new_to_add.push_back(bigfield(ctx, uint256_t(remainder_value))); - result = sum(new_to_add); - } - if (fix_remainder_to_zero) { - result.self_reduce(); - result.assert_equal(zero()); - } - return result; - } - } - - // Now that we know that there is at least 1 non-constant multiplication, we can start estimating reductions, etc - - // Compute the constant term we're adding - const auto [_, constant_part_remainder_1024] = (sum_of_constant_products + add_right_constant_sum).divmod(modulus); - const uint256_t constant_part_remainder_256 = constant_part_remainder_1024.lo.lo; - - if (constant_part_remainder_256 != uint256_t(0)) { - new_to_add.push_back(bigfield(ctx, constant_part_remainder_256)); - } - // Compute added sum - uint1024_t add_right_final_sum(0); - uint1024_t add_right_maximum(0); - for (const auto& add_element : new_to_add) { - // Technically not needed, but better to leave just in case - add_element.reduction_check(); - add_right_final_sum += uint1024_t(add_element.get_value()); - + add_right += uint1024_t(add_element.get_value()); add_right_maximum += uint1024_t(add_element.get_maximum_value()); + add_constant = add_constant && (add_element.is_constant()); } - const size_t final_number_of_products = new_input_left.size(); - - // We need to check if it is possible to reduce the products enough - worst_case_product_sum = uint1024_t(final_number_of_products) * uint1024_t(DEFAULT_MAXIMUM_REMAINDER) * - uint1024_t(DEFAULT_MAXIMUM_REMAINDER); + worst_case_product_sum = + uint1024_t(number_of_products) * uint1024_t(DEFAULT_MAXIMUM_REMAINDER) * uint1024_t(DEFAULT_MAXIMUM_REMAINDER); // Check that we can actually reduce the products enough, this assert will probably never get triggered ASSERT((worst_case_product_sum + add_right_maximum) < get_maximum_crt_product()); - // We've collapsed all constants, checked if we can compute the sum of products in the worst case, time to check if - // we need to reduce something - perform_reductions_for_mult_madd(new_input_left, new_input_right, new_to_add); - uint1024_t sum_of_products_final(0); - for (size_t i = 0; i < final_number_of_products; i++) { - sum_of_products_final += uint1024_t(new_input_left[i].get_value()) * uint1024_t(new_input_right[i].get_value()); - } + perform_reductions_for_mult_madd(mutable_mul_left, mutable_mul_right, to_add); // Get the number of range proof bits for the quotient const size_t num_quotient_bits = get_quotient_max_bits({ DEFAULT_MAXIMUM_REMAINDER }); + // TODO: Could probably search through all + C* ctx = mutable_mul_left[0].context ? mutable_mul_left[0].context : mutable_mul_right[0].context; + + // Compute the product sum + // And check if all the multiplied values are constant + uint1024_t product_sum(0); + bool product_sum_constant = true; + for (size_t i = 0; i < number_of_products; i++) { + product_sum += uint1024_t(mutable_mul_left[i].get_value()) * uint1024_t(mutable_mul_right[i].get_value()); + product_sum_constant = + product_sum_constant && mutable_mul_left[i].is_constant() && mutable_mul_right[i].is_constant(); + } + // Compute the quotient and remainder - const auto [quotient_1024, remainder_1024] = (sum_of_products_final + add_right_final_sum).divmod(modulus); + const auto [quotient_1024, remainder_1024] = (product_sum + add_right).divmod(modulus); // If we are establishing an identity and the remainder has to be zero, we need to check, that it actually is @@ -1082,23 +955,31 @@ bigfield bigfield::mult_madd(const std::vector& mul_left, bigfield remainder; bigfield quotient; - // Constrain quotient to mitigate CRT overflow attacks - quotient = bigfield(witness_t(ctx, fr(quotient_value.slice(0, NUM_LIMB_BITS * 2).lo)), - witness_t(ctx, fr(quotient_value.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 4).lo)), - false, - num_quotient_bits); - if (fix_remainder_to_zero) { - remainder = zero(); - // remainder needs to be defined as wire value and not selector values to satisfy - // UltraPlonk's bigfield custom gates - remainder.convert_constant_to_witness(ctx); + // If all was constant, just create a new constant + if (product_sum_constant && add_constant) { + // We don't check fix_remainder_to_zero here because it makes absolutely no sense + remainder = bigfield(ctx, uint256_t(remainder_value.lo)); + return remainder; } else { - remainder = bigfield( - witness_t(ctx, fr(remainder_value.slice(0, NUM_LIMB_BITS * 2).lo)), - witness_t(ctx, fr(remainder_value.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 3 + NUM_LAST_LIMB_BITS).lo))); - } + // Constrain quotient to mitigate CRT overflow attacks + quotient = bigfield(witness_t(ctx, fr(quotient_value.slice(0, NUM_LIMB_BITS * 2).lo)), + witness_t(ctx, fr(quotient_value.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 4).lo)), + false, + num_quotient_bits); + if (fix_remainder_to_zero) { + remainder = zero(); + // remainder needs to be defined as wire value and not selector values to satisfy + // UltraPlonk's bigfield custom gates + remainder.convert_constant_to_witness(ctx); + } else { + remainder = bigfield( + witness_t(ctx, fr(remainder_value.slice(0, NUM_LIMB_BITS * 2).lo)), + witness_t(ctx, + fr(remainder_value.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 3 + NUM_LAST_LIMB_BITS).lo))); + } + }; - unsafe_evaluate_multiple_multiply_add(new_input_left, new_input_right, new_to_add, quotient, { remainder }); + unsafe_evaluate_multiple_multiply_add(mutable_mul_left, mutable_mul_right, to_add, quotient, { remainder }); return remainder; } @@ -1296,9 +1177,9 @@ bigfield bigfield::conditional_select(const bigfield& other, const b { if (is_constant() && other.is_constant() && predicate.is_constant()) { if (predicate.get_value()) { - return other; + return *this; } - return *this; + return other; } C* ctx = context ? context : (other.context ? other.context : predicate.context); @@ -1394,9 +1275,9 @@ template void bigfield::assert_is_in_field() cons bool borrow_0_value = value.slice(0, NUM_LIMB_BITS) > modulus_minus_one_0; bool borrow_1_value = - (value.slice(NUM_LIMB_BITS, NUM_LIMB_BITS * 2) + uint256_t(borrow_0_value)) > (modulus_minus_one_1); + (value.slice(NUM_LIMB_BITS, NUM_LIMB_BITS * 2) - uint256_t(borrow_0_value)) > modulus_minus_one_1; bool borrow_2_value = - (value.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 3) + uint256_t(borrow_1_value)) > (modulus_minus_one_2); + (value.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 3) - uint256_t(borrow_1_value)) > modulus_minus_one_2; field_t modulus_0(context, modulus_minus_one_0); field_t modulus_1(context, modulus_minus_one_1); @@ -1489,7 +1370,6 @@ template void bigfield::assert_equal(const bigfie // it's non-zero mod r template void bigfield::assert_is_not_equal(const bigfield& other) const { - // Why would we use this for 2 constants? Turns out, in biggroup const auto get_overload_count = [target_modulus = modulus_u512](const uint512_t& maximum_value) { uint512_t target = target_modulus; size_t overload_count = 0; @@ -1603,42 +1483,42 @@ void bigfield::unsafe_evaluate_multiply_add(const bigfield& input_left, C* ctx = left.context ? left.context : to_mul.context; - uint512_t max_b0 = (left.binary_basis_limbs[1].maximum_value * to_mul.binary_basis_limbs[0].maximum_value); + uint256_t max_b0 = (left.binary_basis_limbs[1].maximum_value * to_mul.binary_basis_limbs[0].maximum_value); max_b0 += (neg_modulus_limbs_u256[1] * quotient.binary_basis_limbs[0].maximum_value); - uint512_t max_b1 = (left.binary_basis_limbs[0].maximum_value * to_mul.binary_basis_limbs[1].maximum_value); + uint256_t max_b1 = (left.binary_basis_limbs[0].maximum_value * to_mul.binary_basis_limbs[1].maximum_value); max_b1 += (neg_modulus_limbs_u256[0] * quotient.binary_basis_limbs[1].maximum_value); - uint512_t max_c0 = (left.binary_basis_limbs[1].maximum_value * to_mul.binary_basis_limbs[1].maximum_value); + uint256_t max_c0 = (left.binary_basis_limbs[1].maximum_value * to_mul.binary_basis_limbs[1].maximum_value); max_c0 += (neg_modulus_limbs_u256[1] * quotient.binary_basis_limbs[1].maximum_value); - uint512_t max_c1 = (left.binary_basis_limbs[2].maximum_value * to_mul.binary_basis_limbs[0].maximum_value); + uint256_t max_c1 = (left.binary_basis_limbs[2].maximum_value * to_mul.binary_basis_limbs[0].maximum_value); max_c1 += (neg_modulus_limbs_u256[2] * quotient.binary_basis_limbs[0].maximum_value); - uint512_t max_c2 = (left.binary_basis_limbs[0].maximum_value * to_mul.binary_basis_limbs[2].maximum_value); + uint256_t max_c2 = (left.binary_basis_limbs[0].maximum_value * to_mul.binary_basis_limbs[2].maximum_value); max_c2 += (neg_modulus_limbs_u256[0] * quotient.binary_basis_limbs[2].maximum_value); - uint512_t max_d0 = (left.binary_basis_limbs[3].maximum_value * to_mul.binary_basis_limbs[0].maximum_value); + uint256_t max_d0 = (left.binary_basis_limbs[3].maximum_value * to_mul.binary_basis_limbs[0].maximum_value); max_d0 += (neg_modulus_limbs_u256[3] * quotient.binary_basis_limbs[0].maximum_value); - uint512_t max_d1 = (left.binary_basis_limbs[2].maximum_value * to_mul.binary_basis_limbs[1].maximum_value); + uint256_t max_d1 = (left.binary_basis_limbs[2].maximum_value * to_mul.binary_basis_limbs[1].maximum_value); max_d1 += (neg_modulus_limbs_u256[2] * quotient.binary_basis_limbs[1].maximum_value); - uint512_t max_d2 = (left.binary_basis_limbs[1].maximum_value * to_mul.binary_basis_limbs[2].maximum_value); + uint256_t max_d2 = (left.binary_basis_limbs[1].maximum_value * to_mul.binary_basis_limbs[2].maximum_value); max_d2 += (neg_modulus_limbs_u256[1] * quotient.binary_basis_limbs[2].maximum_value); - uint512_t max_d3 = (left.binary_basis_limbs[0].maximum_value * to_mul.binary_basis_limbs[3].maximum_value); + uint256_t max_d3 = (left.binary_basis_limbs[0].maximum_value * to_mul.binary_basis_limbs[3].maximum_value); max_d3 += (neg_modulus_limbs_u256[0] * quotient.binary_basis_limbs[3].maximum_value); - uint512_t max_r0 = left.binary_basis_limbs[0].maximum_value * to_mul.binary_basis_limbs[0].maximum_value; + uint256_t max_r0 = left.binary_basis_limbs[0].maximum_value * to_mul.binary_basis_limbs[0].maximum_value; max_r0 += (neg_modulus_limbs_u256[0] * quotient.binary_basis_limbs[0].maximum_value); - const uint512_t max_r1 = max_b0 + max_b1; - const uint512_t max_r2 = max_c0 + max_c1 + max_c2; - const uint512_t max_r3 = max_d0 + max_d1 + max_d2 + max_d3; + const uint256_t max_r1 = max_b0 + max_b1; + const uint256_t max_r2 = max_c0 + max_c1 + max_c2; + const uint256_t max_r3 = max_d0 + max_d1 + max_d2 + max_d3; - uint512_t max_a0(0); - uint512_t max_a1(0); + uint256_t max_a0(0); + uint256_t max_a1(0); for (size_t i = 0; i < to_add.size(); ++i) { max_a0 += to_add[i].binary_basis_limbs[0].maximum_value + (to_add[i].binary_basis_limbs[1].maximum_value << NUM_LIMB_BITS); max_a1 += to_add[i].binary_basis_limbs[2].maximum_value + (to_add[i].binary_basis_limbs[3].maximum_value << NUM_LIMB_BITS); } - const uint512_t max_lo = max_r0 + (max_r1 << NUM_LIMB_BITS) + max_a0; - const uint512_t max_hi = max_r2 + (max_r3 << NUM_LIMB_BITS) + max_a1; + const uint256_t max_lo = max_r0 + (max_r1 << NUM_LIMB_BITS) + max_a0; + const uint256_t max_hi = max_r2 + (max_r3 << NUM_LIMB_BITS) + max_a1; uint64_t max_lo_bits = (max_lo.get_msb() + 1); uint64_t max_hi_bits = max_hi.get_msb() + 1; @@ -1745,20 +1625,13 @@ void bigfield::unsafe_evaluate_multiply_add(const bigfield& input_left, ctx->decompose_into_default_range(carry_hi.witness_index, static_cast(carry_hi_msb)); } else { - if ((carry_hi_msb + carry_lo_msb) < field_t::modulus.get_msb()) { - field_t carry_combined = carry_lo + (carry_hi * carry_lo_shift); - carry_combined = carry_combined.normalize(); - const auto accumulators = ctx->decompose_into_base4_accumulators( - carry_combined.witness_index, static_cast(carry_lo_msb + carry_hi_msb)); - field_t accumulator_midpoint = - field_t::from_witness_index(ctx, accumulators[static_cast((carry_hi_msb / 2) - 1)]); - carry_hi.assert_equal(accumulator_midpoint, "bigfield multiply range check failed"); - } else { - carry_lo = carry_lo.normalize(); - carry_hi = carry_hi.normalize(); - ctx->decompose_into_base4_accumulators(carry_lo.witness_index, static_cast(carry_lo_msb)); - ctx->decompose_into_base4_accumulators(carry_hi.witness_index, static_cast(carry_hi_msb)); - } + field_t carry_combined = carry_lo + (carry_hi * carry_lo_shift); + carry_combined = carry_combined.normalize(); + const auto accumulators = ctx->decompose_into_base4_accumulators( + carry_combined.witness_index, static_cast(carry_lo_msb + carry_hi_msb)); + field_t accumulator_midpoint = + field_t::from_witness_index(ctx, accumulators[static_cast((carry_hi_msb / 2) - 1)]); + carry_hi.assert_equal(accumulator_midpoint, "bigfield multiply range check failed"); } } @@ -1801,23 +1674,23 @@ void bigfield::unsafe_evaluate_multiple_multiply_add(const std::vector(max_lo_temp, max_hi_temp); + uint256_t max_b0_inner = (left.binary_basis_limbs[1].maximum_value * right.binary_basis_limbs[0].maximum_value); + uint256_t max_b1_inner = (left.binary_basis_limbs[0].maximum_value * right.binary_basis_limbs[1].maximum_value); + uint256_t max_c0_inner = (left.binary_basis_limbs[1].maximum_value * right.binary_basis_limbs[1].maximum_value); + uint256_t max_c1_inner = (left.binary_basis_limbs[2].maximum_value * right.binary_basis_limbs[0].maximum_value); + uint256_t max_c2_inner = (left.binary_basis_limbs[0].maximum_value * right.binary_basis_limbs[2].maximum_value); + uint256_t max_d0_inner = (left.binary_basis_limbs[3].maximum_value * right.binary_basis_limbs[0].maximum_value); + uint256_t max_d1_inner = (left.binary_basis_limbs[2].maximum_value * right.binary_basis_limbs[1].maximum_value); + uint256_t max_d2_inner = (left.binary_basis_limbs[1].maximum_value * right.binary_basis_limbs[2].maximum_value); + uint256_t max_d3_inner = (left.binary_basis_limbs[0].maximum_value * right.binary_basis_limbs[3].maximum_value); + uint256_t max_r0_inner = left.binary_basis_limbs[0].maximum_value * right.binary_basis_limbs[0].maximum_value; + + const uint256_t max_r1_inner = max_b0_inner + max_b1_inner; + const uint256_t max_r2_inner = max_c0_inner + max_c1_inner + max_c2_inner; + const uint256_t max_r3_inner = max_d0_inner + max_d1_inner + max_d2_inner + max_d3_inner; + const uint256_t max_lo_temp = max_r0_inner + (max_r1_inner << NUM_LIMB_BITS); + const uint256_t max_hi_temp = max_r2_inner + (max_r3_inner << NUM_LIMB_BITS); + return std::pair(max_lo_temp, max_hi_temp); }; /** @@ -1827,37 +1700,37 @@ void bigfield::unsafe_evaluate_multiple_multiply_add(const std::vector::unsafe_evaluate_multiple_multiply_add(const std::vectordecompose_into_default_range(carry_hi.witness_index, static_cast(carry_hi_msb)); } else { - if ((carry_hi_msb + carry_lo_msb) < field_t::modulus.get_msb()) { - field_t carry_combined = carry_lo + (carry_hi * carry_lo_shift); - carry_combined = carry_combined.normalize(); - const auto accumulators = ctx->decompose_into_base4_accumulators( - carry_combined.witness_index, static_cast(carry_lo_msb + carry_hi_msb)); - field_t accumulator_midpoint = - field_t::from_witness_index(ctx, accumulators[static_cast((carry_hi_msb / 2) - 1)]); - carry_hi.assert_equal(accumulator_midpoint, "bigfield multiply range check failed"); - } else { - carry_lo = carry_lo.normalize(); - carry_hi = carry_hi.normalize(); - ctx->decompose_into_base4_accumulators(carry_lo.witness_index, static_cast(carry_lo_msb)); - ctx->decompose_into_base4_accumulators(carry_hi.witness_index, static_cast(carry_hi_msb)); - } + field_t carry_combined = carry_lo + (carry_hi * carry_lo_shift); + carry_combined = carry_combined.normalize(); + const auto accumulators = ctx->decompose_into_base4_accumulators( + carry_combined.witness_index, static_cast(carry_lo_msb + carry_hi_msb)); + field_t accumulator_midpoint = + field_t::from_witness_index(ctx, accumulators[static_cast((carry_hi_msb / 2) - 1)]); + carry_hi.assert_equal(accumulator_midpoint, "bigfield multiply range check failed"); } } @@ -2091,38 +1957,38 @@ void bigfield::unsafe_evaluate_square_add(const bigfield& left, } C* ctx = left.context == nullptr ? quotient.context : left.context; - uint512_t max_b0 = (left.binary_basis_limbs[1].maximum_value * left.binary_basis_limbs[0].maximum_value); + uint256_t max_b0 = (left.binary_basis_limbs[1].maximum_value * left.binary_basis_limbs[0].maximum_value); max_b0 += (neg_modulus_limbs_u256[1] << NUM_LIMB_BITS); max_b0 += max_b0; - uint512_t max_c0 = (left.binary_basis_limbs[1].maximum_value * left.binary_basis_limbs[1].maximum_value); + uint256_t max_c0 = (left.binary_basis_limbs[1].maximum_value * left.binary_basis_limbs[1].maximum_value); max_c0 += (neg_modulus_limbs_u256[1] << NUM_LIMB_BITS); - uint512_t max_c1 = (left.binary_basis_limbs[2].maximum_value * left.binary_basis_limbs[0].maximum_value); + uint256_t max_c1 = (left.binary_basis_limbs[2].maximum_value * left.binary_basis_limbs[0].maximum_value); max_c1 += (neg_modulus_limbs_u256[2] << NUM_LIMB_BITS); max_c1 += max_c1; - uint512_t max_d0 = (left.binary_basis_limbs[3].maximum_value * left.binary_basis_limbs[0].maximum_value); + uint256_t max_d0 = (left.binary_basis_limbs[3].maximum_value * left.binary_basis_limbs[0].maximum_value); max_d0 += (neg_modulus_limbs_u256[3] << NUM_LIMB_BITS); max_d0 += max_d0; - uint512_t max_d1 = (left.binary_basis_limbs[2].maximum_value * left.binary_basis_limbs[1].maximum_value); + uint256_t max_d1 = (left.binary_basis_limbs[2].maximum_value * left.binary_basis_limbs[1].maximum_value); max_d1 += (neg_modulus_limbs_u256[2] << NUM_LIMB_BITS); max_d1 += max_d1; - uint512_t max_r0 = left.binary_basis_limbs[0].maximum_value * left.binary_basis_limbs[0].maximum_value; + uint256_t max_r0 = left.binary_basis_limbs[0].maximum_value * left.binary_basis_limbs[0].maximum_value; max_r0 += (neg_modulus_limbs_u256[0] << NUM_LIMB_BITS); - const uint512_t max_r1 = max_b0; - const uint512_t max_r2 = max_c0 + max_c1; - const uint512_t max_r3 = max_d0 + max_d1; + const uint256_t max_r1 = max_b0; + const uint256_t max_r2 = max_c0 + max_c1; + const uint256_t max_r3 = max_d0 + max_d1; - uint512_t max_a0(0); - uint512_t max_a1(1); + uint256_t max_a0(0); + uint256_t max_a1(1); for (size_t i = 0; i < to_add.size(); ++i) { max_a0 += to_add[i].binary_basis_limbs[0].maximum_value + (to_add[i].binary_basis_limbs[1].maximum_value << NUM_LIMB_BITS); max_a1 += to_add[i].binary_basis_limbs[2].maximum_value + (to_add[i].binary_basis_limbs[3].maximum_value << NUM_LIMB_BITS); } - const uint512_t max_lo = max_r0 + (max_r1 << NUM_LIMB_BITS) + max_a0; - const uint512_t max_hi = max_r2 + (max_r3 << NUM_LIMB_BITS) + max_a1; + const uint256_t max_lo = max_r0 + (max_r1 << NUM_LIMB_BITS) + max_a0; + const uint256_t max_hi = max_r2 + (max_r3 << NUM_LIMB_BITS) + max_a1; uint64_t max_lo_bits = max_lo.get_msb() + 1; uint64_t max_hi_bits = max_hi.get_msb() + 1; @@ -2192,13 +2058,14 @@ void bigfield::unsafe_evaluate_square_add(const bigfield& left, barretenberg::fr neg_prime = -barretenberg::fr(uint256_t(target_basis.modulus)); field_t linear_terms = -remainder.prime_basis_limb; if (to_add.size() >= 2) { - for (size_t i = 0; i < to_add.size() / 2; i += 1) { - linear_terms = linear_terms.add_two(to_add[2 * i].prime_basis_limb, to_add[2 * i + 1].prime_basis_limb); + for (size_t i = 0; i < to_add.size(); i += 2) { + linear_terms = linear_terms.add_two(to_add[i].prime_basis_limb, to_add[i + 1].prime_basis_limb); } } if ((to_add.size() & 1UL) == 1UL) { linear_terms += to_add[to_add.size() - 1].prime_basis_limb; } + field_t::evaluate_polynomial_identity( left.prime_basis_limb, left.prime_basis_limb, quotient.prime_basis_limb * neg_prime, linear_terms); @@ -2213,20 +2080,13 @@ void bigfield::unsafe_evaluate_square_add(const bigfield& left, ctx->decompose_into_default_range(carry_hi.witness_index, static_cast(carry_hi_msb)); } else { - if ((carry_hi_msb + carry_lo_msb) < field_t::modulus.get_msb()) { - field_t carry_combined = carry_lo + (carry_hi * carry_lo_shift); - carry_combined = carry_combined.normalize(); - const auto accumulators = ctx->decompose_into_base4_accumulators( - carry_combined.witness_index, static_cast(carry_lo_msb + carry_hi_msb)); - field_t accumulator_midpoint = - field_t::from_witness_index(ctx, accumulators[static_cast((carry_hi_msb / 2) - 1)]); - carry_hi.assert_equal(accumulator_midpoint, "bigfield multiply range check failed"); - } else { - carry_lo = carry_lo.normalize(); - carry_hi = carry_hi.normalize(); - ctx->decompose_into_base4_accumulators(carry_lo.witness_index, static_cast(carry_lo_msb)); - ctx->decompose_into_base4_accumulators(carry_hi.witness_index, static_cast(carry_hi_msb)); - } + field_t carry_combined = carry_lo + (carry_hi * carry_lo_shift); + carry_combined = carry_combined.normalize(); + const auto accumulators = ctx->decompose_into_base4_accumulators( + carry_combined.witness_index, static_cast(carry_lo_msb + carry_hi_msb)); + field_t accumulator_midpoint = + field_t::from_witness_index(ctx, accumulators[static_cast((carry_hi_msb / 2) - 1)]); + carry_hi.assert_equal(accumulator_midpoint, "bigfield multiply range check failed"); } } @@ -2283,14 +2143,10 @@ std::pair bigfield::get_quotient_reduction_info(const std::v if (mul_product_overflows_crt_modulus(as_max, bs_max, to_add)) { return std::pair(true, 0); } + const size_t num_quotient_bits = get_quotient_max_bits(remainders_max); - std::vector to_add_max; - for (auto& added_element : to_add) { - to_add_max.push_back(added_element.get_maximum_value()); - } // Get maximum value of quotient - const uint512_t maximum_quotient = compute_maximum_quotient_value(as_max, bs_max, to_add_max); - + const uint512_t maximum_quotient = compute_maximum_quotient_value(as_max, bs_max, {}); // Check if the quotient can fit into the range proof if (maximum_quotient >= (uint512_t(1) << num_quotient_bits)) { return std::pair(true, 0); diff --git a/src/aztec/stdlib/primitives/field/field.cpp b/src/aztec/stdlib/primitives/field/field.cpp index 1a526e3a491..22103c0cb02 100644 --- a/src/aztec/stdlib/primitives/field/field.cpp +++ b/src/aztec/stdlib/primitives/field/field.cpp @@ -575,15 +575,14 @@ void field_t::create_range_constraint(const size_t num_bits, st { if (num_bits == 0) { assert_is_zero("0-bit range_constraint on non-zero field_t."); + } + if (is_constant()) { + ASSERT(uint256_t(get_value()).get_msb() < num_bits); } else { - if (is_constant()) { - ASSERT(uint256_t(get_value()).get_msb() < num_bits); + if constexpr (ComposerContext::type == waffle::ComposerType::PLOOKUP) { + context->decompose_into_default_range(normalize().get_witness_index(), num_bits, msg); } else { - if constexpr (ComposerContext::type == waffle::ComposerType::PLOOKUP) { - context->decompose_into_default_range(normalize().get_witness_index(), num_bits, msg); - } else { - context->decompose_into_base4_accumulators(normalize().get_witness_index(), num_bits, msg); - } + context->decompose_into_base4_accumulators(normalize().get_witness_index(), num_bits, msg); } } } diff --git a/src/aztec/stdlib/primitives/safe_uint/safe_uint.cpp b/src/aztec/stdlib/primitives/safe_uint/safe_uint.cpp index 9962b03c0d0..f07d573c9ec 100644 --- a/src/aztec/stdlib/primitives/safe_uint/safe_uint.cpp +++ b/src/aztec/stdlib/primitives/safe_uint/safe_uint.cpp @@ -19,6 +19,7 @@ safe_uint_t safe_uint_t::operator*(const safe_ uint512_t new_max = uint512_t(current_max) * uint512_t(other.current_max); ASSERT(new_max.hi == 0); + return safe_uint_t((value * other.value), new_max.lo, IS_UNSAFE); } @@ -64,14 +65,10 @@ std::array, 3> safe_uint_t::slice( const uint8_t lsb) const { ASSERT(msb >= lsb); - ASSERT(static_cast(msb) <= rollup::MAX_NO_WRAP_INTEGER_BIT_LENGTH); const safe_uint_t lhs = *this; ComposerContext* ctx = lhs.get_context(); const uint256_t value = uint256_t(get_value()); - // This should be caught by the proof itself, but the circuit creator will have now way of knowing where the issue - // is - ASSERT(value < (static_cast(1) << rollup::MAX_NO_WRAP_INTEGER_BIT_LENGTH)); const auto msb_plus_one = uint32_t(msb) + 1; const auto hi_mask = ((uint256_t(1) << (256 - uint32_t(msb))) - 1); const auto hi = (value >> msb_plus_one) & hi_mask; @@ -81,17 +78,11 @@ std::array, 3> safe_uint_t::slice( const auto slice_mask = ((uint256_t(1) << (uint32_t(msb - lsb) + 1)) - 1); const auto slice = (value >> lsb) & slice_mask; - safe_uint_t lo_wit, slice_wit, hi_wit; - if (this->value.is_constant()) { - hi_wit = safe_uint_t(hi); - lo_wit = safe_uint_t(lo); - slice_wit = safe_uint_t(slice); - - } else { - hi_wit = safe_uint_t(witness_t(ctx, hi), rollup::MAX_NO_WRAP_INTEGER_BIT_LENGTH - uint32_t(msb), "hi_wit"); - lo_wit = safe_uint_t(witness_t(ctx, lo), lsb, "lo_wit"); - slice_wit = safe_uint_t(witness_t(ctx, slice), msb_plus_one - lsb, "slice_wit"); - } + + const safe_uint_t hi_wit(witness_t(ctx, hi), rollup::MAX_NO_WRAP_INTEGER_BIT_LENGTH - uint32_t(msb), "hi_wit"); + const safe_uint_t lo_wit(witness_t(ctx, lo), lsb, "lo_wit"); + const safe_uint_t slice_wit(witness_t(ctx, slice), msb_plus_one - lsb, "slice_wit"); + assert_equal(((hi_wit * safe_uint_t(uint256_t(1) << msb_plus_one)) + lo_wit + (slice_wit * safe_uint_t(uint256_t(1) << lsb)))); diff --git a/src/aztec/stdlib/primitives/safe_uint/safe_uint.hpp b/src/aztec/stdlib/primitives/safe_uint/safe_uint.hpp index 93864fef36a..750e9d26bca 100644 --- a/src/aztec/stdlib/primitives/safe_uint/safe_uint.hpp +++ b/src/aztec/stdlib/primitives/safe_uint/safe_uint.hpp @@ -100,7 +100,6 @@ template class safe_uint_t { std::string const& description = "") const { ASSERT(difference_bit_size <= MAX_BIT_NUM); - ASSERT(!(this->value.is_constant() && other.value.is_constant())); field_ct difference_val = this->value - other.value; safe_uint_t difference(difference_val, difference_bit_size, format("subtract: ", description)); // This checks the subtraction is correct for integers without any wraps @@ -111,9 +110,6 @@ template class safe_uint_t { safe_uint_t operator-(const safe_uint_t& other) const { - // We could get a constant underflow - ASSERT(!(this->value.is_constant() && other.value.is_constant() && - static_cast(value.get_value()) < static_cast(other.value.get_value()))); field_ct difference_val = this->value - other.value; safe_uint_t difference(difference_val, (size_t)(current_max.get_msb() + 1), "- operator"); // This checks the subtraction is correct for integers without any wraps @@ -164,7 +160,8 @@ template class safe_uint_t { { ASSERT(this->value.is_constant() == false); uint256_t val = this->value.get_value(); - auto [quotient_val, remainder_val] = val.divmod((uint256_t)other.value.get_value()); + auto quotient_val = (uint256_t)(val / (uint256_t)other.value.get_value()); + auto remainder_val = (uint256_t)(val % (uint256_t)other.value.get_value()); field_ct quotient_field(witness_t(value.context, quotient_val)); field_ct remainder_field(witness_t(value.context, remainder_val)); safe_uint_t quotient( diff --git a/src/aztec/stdlib/primitives/safe_uint/safe_uint.test.cpp b/src/aztec/stdlib/primitives/safe_uint/safe_uint.test.cpp index d7b5aa532be..83a48e877a6 100644 --- a/src/aztec/stdlib/primitives/safe_uint/safe_uint.test.cpp +++ b/src/aztec/stdlib/primitives/safe_uint/safe_uint.test.cpp @@ -260,15 +260,19 @@ TEST(stdlib_safeuint, test_divide_method_quotient_range_too_small_fails) EXPECT_EQ(result.err, "safe_uint_t range constraint failure: divide method quotient: d/c"); } -TEST(stdlib_safeuint, test_divide_remainder_too_large) +TEST(stdlib_safeuint, test_divide_method_remainder_range_too_small_fails) { // test failure when range for remainder too small waffle::TurboComposer composer = waffle::TurboComposer(); field_t a(witness_t(&composer, 5)); + field_t b(witness_t(&composer, 19)); suint_t c(a, 3); - suint_t d((fr::modulus - 1) / 3); - suint_t b; - EXPECT_ANY_THROW(b = c / d); + suint_t d(b, 5); + d = d.divide(c, 3, 1, "d/c"); + + auto result = verify_logic(composer); + EXPECT_FALSE(result.valid); + EXPECT_EQ(result.err, "safe_uint_t range constraint failure: divide method remainder: d/c"); } TEST(stdlib_safeuint, test_divide_method_quotient_remainder_incorrect_fails)