From 722ec5c3dfdc2a5e467528ed94a25677f8800087 Mon Sep 17 00:00:00 2001 From: Zachary James Williamson Date: Wed, 30 Oct 2024 14:54:29 -0700 Subject: [PATCH] feat: faster square roots (#2694) We use the Tonelli-Shanks square root algorithm to perform square roots, required for deriving generators on the Grumpkin curve. Our existing implementation uses a slow algorithm that requires ~1,000 field multiplications per square root. This PR implements a newer algorithm by Bernstein that uses precomputed lookup tables to increase performance https://cr.yp.to/papers/sqroot-20011123-retypeset20220327.pdf --- .../ecc/curves/secp256k1/secp256k1.test.cpp | 11 + .../ecc/fields/field_declarations.hpp | 6 +- .../barretenberg/ecc/fields/field_impl.hpp | 236 ++++++++++++------ .../barretenberg_wasm_main/index.ts | 2 +- .../foundation/src/wasm/wasm_module.ts | 2 +- 5 files changed, 178 insertions(+), 79 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/ecc/curves/secp256k1/secp256k1.test.cpp b/barretenberg/cpp/src/barretenberg/ecc/curves/secp256k1/secp256k1.test.cpp index 60d0f24af2a..0132643ced5 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/curves/secp256k1/secp256k1.test.cpp +++ b/barretenberg/cpp/src/barretenberg/ecc/curves/secp256k1/secp256k1.test.cpp @@ -134,6 +134,17 @@ TEST(secp256k1, TestSqr) } } +TEST(secp256k1, SqrtRandom) +{ + size_t n = 1; + for (size_t i = 0; i < n; ++i) { + secp256k1::fq input = secp256k1::fq::random_element().sqr(); + auto [is_sqr, root] = input.sqrt(); + secp256k1::fq root_test = root.sqr(); + EXPECT_EQ(root_test, input); + } +} + TEST(secp256k1, TestArithmetic) { secp256k1::fq a = secp256k1::fq::random_element(); diff --git a/barretenberg/cpp/src/barretenberg/ecc/fields/field_declarations.hpp b/barretenberg/cpp/src/barretenberg/ecc/fields/field_declarations.hpp index 17f0113101d..af1643bdc1b 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/fields/field_declarations.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/fields/field_declarations.hpp @@ -330,8 +330,10 @@ template struct alignas(32) field { * * @return if the element is a quadratic remainder, if it's not */ - constexpr std::pair sqrt() const noexcept; - + constexpr std::pair sqrt() const noexcept + requires((Params_::modulus_0 & 0x3UL) == 0x3UL); + constexpr std::pair sqrt() const noexcept + requires((Params_::modulus_0 & 0x3UL) != 0x3UL); BB_INLINE constexpr void self_neg() & noexcept; BB_INLINE constexpr void self_to_montgomery_form() & noexcept; diff --git a/barretenberg/cpp/src/barretenberg/ecc/fields/field_impl.hpp b/barretenberg/cpp/src/barretenberg/ecc/fields/field_impl.hpp index 7f92fd299c1..4ab221fc719 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/fields/field_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/fields/field_impl.hpp @@ -452,102 +452,187 @@ template void field::batch_invert(std::span coeffs) noexcept } } +/** + * @brief Implements an optimised variant of Tonelli-Shanks via lookup tables. + * Algorithm taken from https://cr.yp.to/papers/sqroot-20011123-retypeset20220327.pdf + * "FASTER SQUARE ROOTS IN ANNOYING FINITE FIELDS" by D. Bernstein + * Page 5 "Accelerated Discrete Logarithm" + * @tparam T + * @return constexpr field + */ template constexpr field field::tonelli_shanks_sqrt() const noexcept { BB_OP_COUNT_TRACK_NAME("fr::tonelli_shanks_sqrt"); // Tonelli-shanks algorithm begins by finding a field element Q and integer S, // such that (p - 1) = Q.2^{s} - - // We can compute the square root of a, by considering a^{(Q + 1) / 2} = R - // Once we have found such an R, we have - // R^{2} = a^{Q + 1} = a^{Q}a - // If a^{Q} = 1, we have found our square root. - // Otherwise, we have a^{Q} = t, where t is a 2^{s-1}'th root of unity. - // This is because t^{2^{s-1}} = a^{Q.2^{s-1}}. - // We know that (p - 1) = Q.w^{s}, therefore t^{2^{s-1}} = a^{(p - 1) / 2} - // From Euler's criterion, if a is a quadratic residue, a^{(p - 1) / 2} = 1 - // i.e. t^{2^{s-1}} = 1 - - // To proceed with computing our square root, we want to transform t into a smaller subgroup, - // specifically, the (s-2)'th roots of unity. - // We do this by finding some value b,such that - // (t.b^2)^{2^{s-2}} = 1 and R' = R.b - // Finding such a b is trivial, because from Euler's criterion, we know that, - // for any quadratic non-residue z, z^{(p - 1) / 2} = -1 - // i.e. z^{Q.2^{s-1}} = -1 - // => z^Q is a 2^{s-1}'th root of -1 - // => z^{Q^2} is a 2^{s-2}'th root of -1 - // Since t^{2^{s-1}} = 1, we know that t^{2^{s - 2}} = -1 - // => t.z^{Q^2} is a 2^{s - 2}'th root of unity. - - // We can iteratively transform t into ever smaller subgroups, until t = 1. - // At each iteration, we need to find a new value for b, which we can obtain - // by repeatedly squaring z^{Q} - constexpr uint256_t Q = (modulus - 1) >> static_cast(primitive_root_log_size() - 1); - constexpr uint256_t Q_minus_one_over_two = (Q - 1) >> 2; - - // __to_montgomery_form(Q_minus_one_over_two, Q_minus_one_over_two); - field z = coset_generator(0); // the generator is a non-residue - field b = pow(Q_minus_one_over_two); - field r = operator*(b); // r = a^{(Q + 1) / 2} - field t = r * b; // t = a^{(Q - 1) / 2 + (Q + 1) / 2} = a^{Q} + // We can determine s by counting the least significant set bit of `p - 1` + // We pick elements `r, g` such that g = r^Q and r is not a square. + // (the coset generators are all nonresidues and satisfy this condition) + // + // To find the square root of `u`, consider `v = u^(Q - 1 / 2)` + // There exists an integer `e` where uv^2 = g^e (see Theorem 3.1 in paper). + // If `u` is a square, `e` is even and (uvg^{−e/2})^2 = u^2v^2g^e = u^{Q+1}g^{-e} = u + // + // The goal of the algorithm is two fold: + // 1. find `e` given `u` + // 2. compute `sqrt(u) = uvg^{−e/2}` + constexpr uint256_t Q = (modulus - 1) >> static_cast(primitive_root_log_size()); + constexpr uint256_t Q_minus_one_over_two = (Q - 1) >> 1; + field v = pow(Q_minus_one_over_two); + field uv = operator*(v); // uv = u^{(Q + 1) / 2} + // uvv = g^e for some unknown e. Goal is to find e. + field uvv = uv * v; // uvv = u^{(Q - 1) / 2 + (Q + 1) / 2} = u^{Q} // check if t is a square with euler's criterion // if not, we don't have a quadratic residue and a has no square root! - field check = t; + field check = uvv; for (size_t i = 0; i < primitive_root_log_size() - 1; ++i) { check.self_sqr(); } - if (check != one()) { - return zero(); + if (check != 1) { + return 0; } - field t1 = z.pow(Q_minus_one_over_two); - field t2 = t1 * z; - field c = t2 * t1; // z^Q - size_t m = primitive_root_log_size(); + constexpr field g = coset_generator(0).pow(Q); + constexpr field g_inv = coset_generator(0).pow(modulus - 1 - Q); + constexpr size_t root_bits = primitive_root_log_size(); + constexpr size_t table_bits = 6; + constexpr size_t num_tables = root_bits / table_bits + (root_bits % table_bits != 0 ? 1 : 0); + constexpr size_t num_offset_tables = num_tables - 1; + constexpr size_t table_size = static_cast(1UL) << table_bits; + + using GTable = std::array; + constexpr auto get_g_table = [&](const field& h) { + GTable result; + result[0] = 1; + for (size_t i = 1; i < table_size; ++i) { + result[i] = result[i - 1] * h; + } + return result; + }; + constexpr std::array g_tables = [&]() { + field working_base = g_inv; + std::array result; + for (size_t i = 0; i < num_tables; ++i) { + result[i] = get_g_table(working_base); + for (size_t j = 0; j < table_bits; ++j) { + working_base.self_sqr(); + } + } + return result; + }(); + constexpr std::array offset_g_tables = [&]() { + field working_base = g_inv; + for (size_t i = 0; i < root_bits % table_bits; ++i) { + working_base.self_sqr(); + } + std::array result; + for (size_t i = 0; i < num_offset_tables; ++i) { + result[i] = get_g_table(working_base); + for (size_t j = 0; j < table_bits; ++j) { + working_base.self_sqr(); + } + } + return result; + }(); + + constexpr GTable root_table_a = get_g_table(g.pow(1UL << ((num_tables - 1) * table_bits))); + constexpr GTable root_table_b = get_g_table(g.pow(1UL << (root_bits - table_bits))); + // compute uvv^{2^table_bits}, uvv^{2^{table_bits*2}}, ..., uvv^{2^{table_bits*num_tables}} + std::array uvv_powers; + field base = uvv; + for (size_t i = 0; i < num_tables - 1; ++i) { + uvv_powers[i] = base; + for (size_t j = 0; j < table_bits; ++j) { + base.self_sqr(); + } + } + uvv_powers[num_tables - 1] = base; + std::array e_slices; + for (size_t i = 0; i < num_tables; ++i) { + size_t table_index = num_tables - 1 - i; + field target = uvv_powers[table_index]; + for (size_t j = 0; j < i; ++j) { + size_t e_idx = num_tables - 1 - (i - 1) + j; + size_t g_idx = num_tables - 2 - j; + + field g_lookup; + if (j != i - 1) { + g_lookup = offset_g_tables[g_idx - 1][e_slices[e_idx]]; // e1 + } else { + g_lookup = g_tables[g_idx][e_slices[e_idx]]; + } + target *= g_lookup; + } + size_t count = 0; + + if (i == 0) { + for (auto& x : root_table_a) { + if (x == target) { + break; + } + count += 1; + } + } else { + for (auto& x : root_table_b) { + if (x == target) { + break; + } + count += 1; + } + } - while (t != one()) { - size_t i = 0; - field t2m = t; + ASSERT(count != table_size); + e_slices[table_index] = count; + } - // find the smallest value of m, such that t^{2^m} = 1 - while (t2m != one()) { - t2m.self_sqr(); - i += 1; + // We want to compute g^{-e/2} which requires computing `e/2` via our slice representation + for (size_t i = 0; i < num_tables; ++i) { + auto& e_slice = e_slices[num_tables - 1 - i]; + // e_slices[num_tables - 1] is always even. + // From theorem 3.1 (https://cr.yp.to/papers/sqroot-20011123-retypeset20220327.pdf) + // if slice is odd, propagate the downshifted bit into previous slice value + if ((e_slice & 1UL) == 1UL) { + size_t borrow_value = (i == 1) ? 1UL << ((root_bits % table_bits) - 1) : (1UL << (table_bits - 1)); + e_slices[num_tables - i] += borrow_value; } + e_slice >>= 1; + } - size_t j = m - i - 1; - b = c; - while (j > 0) { - b.self_sqr(); - --j; - } // b = z^2^(m-i-1) - - c = b.sqr(); - t = t * c; - r = r * b; - m = i; + field g_pow_minus_e_over_2 = 1; + for (size_t i = 0; i < num_tables; ++i) { + if (i == 0) { + g_pow_minus_e_over_2 *= g_tables[i][e_slices[num_tables - 1 - i]]; + } else { + g_pow_minus_e_over_2 *= offset_g_tables[i - 1][e_slices[num_tables - 1 - i]]; + } } - return r; + return uv * g_pow_minus_e_over_2; } -template constexpr std::pair> field::sqrt() const noexcept +template +constexpr std::pair> field::sqrt() const noexcept + requires((T::modulus_0 & 0x3UL) == 0x3UL) { BB_OP_COUNT_TRACK_NAME("fr::sqrt"); - field root; - if constexpr ((T::modulus_0 & 0x3UL) == 0x3UL) { - constexpr uint256_t sqrt_exponent = (modulus + uint256_t(1)) >> 2; - root = pow(sqrt_exponent); - } else { - root = tonelli_shanks_sqrt(); - } + constexpr uint256_t sqrt_exponent = (modulus + uint256_t(1)) >> 2; + field root = pow(sqrt_exponent); if ((root * root) == (*this)) { return std::pair(true, root); } return std::pair(false, field::zero()); +} -} // namespace bb; +template +constexpr std::pair> field::sqrt() const noexcept + requires((T::modulus_0 & 0x3UL) != 0x3UL) +{ + field root = tonelli_shanks_sqrt(); + if ((root * root) == (*this)) { + return std::pair(true, root); + } + return std::pair(false, field::zero()); +} template constexpr field field::operator/(const field& other) const noexcept { @@ -634,8 +719,8 @@ constexpr std::array, field::COSET_GENERATOR_SIZE> field::compute size_t count = 1; while (count < n) { - // work_variable contains a new field element, and we need to test that, for all previous vector elements, - // result[i] / work_variable is not a member of our subgroup + // work_variable contains a new field element, and we need to test that, for all previous vector + // elements, result[i] / work_variable is not a member of our subgroup field work_inverse = work_variable.invert(); bool valid = true; for (size_t j = 0; j < count; ++j) { @@ -674,8 +759,9 @@ template void field::msgpack_pack(auto& packer) const // The field is first converted from Montgomery form, similar to how the old format did it. auto adjusted = from_montgomery_form(); - // The data is then converted to big endian format using htonll, which stands for "host to network long long". - // This is necessary because the data will be written to a raw msgpack buffer, which requires big endian format. + // The data is then converted to big endian format using htonll, which stands for "host to network long + // long". This is necessary because the data will be written to a raw msgpack buffer, which requires big + // endian format. uint64_t bin_data[4] = { htonll(adjusted.data[3]), htonll(adjusted.data[2]), htonll(adjusted.data[1]), htonll(adjusted.data[0]) }; @@ -693,8 +779,8 @@ template void field::msgpack_unpack(auto o) // The binary data is first extracted from the msgpack object. std::array raw_data = o; - // The binary data is then read as big endian uint64_t's. This is done by casting the raw data to uint64_t* and then - // using ntohll ("network to host long long") to correct the endianness to the host's endianness. + // The binary data is then read as big endian uint64_t's. This is done by casting the raw data to uint64_t* + // and then using ntohll ("network to host long long") to correct the endianness to the host's endianness. uint64_t* cast_data = (uint64_t*)&raw_data[0]; // NOLINT uint64_t reversed[] = { ntohll(cast_data[3]), ntohll(cast_data[2]), ntohll(cast_data[1]), ntohll(cast_data[0]) }; diff --git a/barretenberg/ts/src/barretenberg_wasm/barretenberg_wasm_main/index.ts b/barretenberg/ts/src/barretenberg_wasm/barretenberg_wasm_main/index.ts index 1be981669ae..6dc26bc0097 100644 --- a/barretenberg/ts/src/barretenberg_wasm/barretenberg_wasm_main/index.ts +++ b/barretenberg/ts/src/barretenberg_wasm/barretenberg_wasm_main/index.ts @@ -32,7 +32,7 @@ export class BarretenbergWasmMain extends BarretenbergWasmBase { module: WebAssembly.Module, threads = Math.min(getNumCpu(), BarretenbergWasmMain.MAX_THREADS), logger: (msg: string) => void = debug, - initial = 30, + initial = 31, maximum = 2 ** 16, ) { this.logger = logger; diff --git a/yarn-project/foundation/src/wasm/wasm_module.ts b/yarn-project/foundation/src/wasm/wasm_module.ts index f0c3a15debe..0b597f4b7bf 100644 --- a/yarn-project/foundation/src/wasm/wasm_module.ts +++ b/yarn-project/foundation/src/wasm/wasm_module.ts @@ -79,7 +79,7 @@ export class WasmModule implements IWasmModule { * @param initMethod - Defaults to calling '_initialize'. * @param maximum - 8192 maximum by default. 512mb. */ - public async init(initial = 30, maximum = 8192, initMethod: string | null = '_initialize') { + public async init(initial = 31, maximum = 8192, initMethod: string | null = '_initialize') { this.debug( `initial mem: ${initial} pages, ${(initial * 2 ** 16) / (1024 * 1024)}mb. max mem: ${maximum} pages, ${ (maximum * 2 ** 16) / (1024 * 1024)