From 382e5f6051cae8d25bf4617b744fdfeb820f19af Mon Sep 17 00:00:00 2001 From: Zachary James Williamson Date: Tue, 28 Nov 2023 15:58:23 +0000 Subject: [PATCH] feat: added poseidon2 hash function to barretenberg/crypto (#3118) Preliminary work to add Poseidon2 hash function as a standard library primitive (https://eprint.iacr.org/2023/323.pdf) Adds Poseidon2 to crypto module, following paper + specification at https://github.com/C2SP/C2SP/blob/792c1254124f625d459bfe34417e8f6bdd02eb28/poseidon-sponge.md --------- Co-authored-by: lucasxia01 --- cpp/src/barretenberg/crypto/CMakeLists.txt | 1 + .../crypto/poseidon2/CMakeLists.txt | 1 + .../crypto/poseidon2/poseidon2.bench.cpp | 29 + .../crypto/poseidon2/poseidon2.hpp | 16 + .../crypto/poseidon2/poseidon2.test.cpp | 49 ++ .../poseidon2/poseidon2_cpp_params.sage | 726 ++++++++++++++++++ .../crypto/poseidon2/poseidon2_params.hpp | 452 +++++++++++ .../poseidon2/poseidon2_permutation.hpp | 156 ++++ .../poseidon2/poseidon2_permutation.test.cpp | 65 ++ .../crypto/poseidon2/sponge/sponge.hpp | 168 ++++ .../dsl/acir_format/ecdsa_secp256r1.test.cpp | 2 +- .../ecc/fields/field_declarations.hpp | 6 + .../barretenberg/numeric/uint256/uint256.hpp | 44 +- .../numeric/uint256/uint256.test.cpp | 21 + 14 files changed, 1730 insertions(+), 6 deletions(-) create mode 100644 cpp/src/barretenberg/crypto/poseidon2/CMakeLists.txt create mode 100644 cpp/src/barretenberg/crypto/poseidon2/poseidon2.bench.cpp create mode 100644 cpp/src/barretenberg/crypto/poseidon2/poseidon2.hpp create mode 100644 cpp/src/barretenberg/crypto/poseidon2/poseidon2.test.cpp create mode 100644 cpp/src/barretenberg/crypto/poseidon2/poseidon2_cpp_params.sage create mode 100644 cpp/src/barretenberg/crypto/poseidon2/poseidon2_params.hpp create mode 100644 cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.hpp create mode 100644 cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.test.cpp create mode 100644 cpp/src/barretenberg/crypto/poseidon2/sponge/sponge.hpp diff --git a/cpp/src/barretenberg/crypto/CMakeLists.txt b/cpp/src/barretenberg/crypto/CMakeLists.txt index 87b519ba15..6efc1f8240 100644 --- a/cpp/src/barretenberg/crypto/CMakeLists.txt +++ b/cpp/src/barretenberg/crypto/CMakeLists.txt @@ -9,3 +9,4 @@ add_subdirectory(schnorr) add_subdirectory(sha256) add_subdirectory(ecdsa) add_subdirectory(aes128) +add_subdirectory(poseidon2) \ No newline at end of file diff --git a/cpp/src/barretenberg/crypto/poseidon2/CMakeLists.txt b/cpp/src/barretenberg/crypto/poseidon2/CMakeLists.txt new file mode 100644 index 0000000000..dc0157be3a --- /dev/null +++ b/cpp/src/barretenberg/crypto/poseidon2/CMakeLists.txt @@ -0,0 +1 @@ +barretenberg_module(crypto_poseidon2 ecc numeric) \ No newline at end of file diff --git a/cpp/src/barretenberg/crypto/poseidon2/poseidon2.bench.cpp b/cpp/src/barretenberg/crypto/poseidon2/poseidon2.bench.cpp new file mode 100644 index 0000000000..603238bf6e --- /dev/null +++ b/cpp/src/barretenberg/crypto/poseidon2/poseidon2.bench.cpp @@ -0,0 +1,29 @@ +#include "./poseidon2.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" +#include + +using namespace benchmark; + +grumpkin::fq poseidon_function(const size_t count) +{ + std::vector inputs(count); + for (size_t i = 0; i < count; ++i) { + inputs[i] = grumpkin::fq::random_element(); + } + std::span tmp(inputs); + // hash count many field elements + inputs[0] = crypto::Poseidon2::hash(tmp); + return inputs[0]; +} + +void native_poseidon2_commitment_bench(State& state) noexcept +{ + for (auto _ : state) { + const size_t count = (static_cast(state.range(0))); + (poseidon_function(count)); + } +} +BENCHMARK(native_poseidon2_commitment_bench)->Arg(10)->Arg(1000)->Arg(10000); + +BENCHMARK_MAIN(); +// } // namespace crypto \ No newline at end of file diff --git a/cpp/src/barretenberg/crypto/poseidon2/poseidon2.hpp b/cpp/src/barretenberg/crypto/poseidon2/poseidon2.hpp new file mode 100644 index 0000000000..15488e2d0b --- /dev/null +++ b/cpp/src/barretenberg/crypto/poseidon2/poseidon2.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "poseidon2_params.hpp" +#include "poseidon2_permutation.hpp" +#include "sponge/sponge.hpp" + +namespace crypto { + +template class Poseidon2 { + public: + using FF = typename Params::FF; + + using Sponge = FieldSponge>; + static FF hash(std::span input) { return Sponge::hash_fixed_length(input); } +}; +} // namespace crypto \ No newline at end of file diff --git a/cpp/src/barretenberg/crypto/poseidon2/poseidon2.test.cpp b/cpp/src/barretenberg/crypto/poseidon2/poseidon2.test.cpp new file mode 100644 index 0000000000..33757efd2b --- /dev/null +++ b/cpp/src/barretenberg/crypto/poseidon2/poseidon2.test.cpp @@ -0,0 +1,49 @@ +#include "poseidon2.hpp" +#include "barretenberg/crypto/poseidon2/poseidon2_params.hpp" +#include "barretenberg/ecc/curves/bn254/bn254.hpp" +#include + +using namespace barretenberg; + +namespace { +auto& engine = numeric::random::get_debug_engine(); +} + +namespace poseidon2_tests { +TEST(Poseidon2, BasicTests) +{ + + barretenberg::fr a = barretenberg::fr::random_element(&engine); + barretenberg::fr b = barretenberg::fr::random_element(&engine); + barretenberg::fr c = barretenberg::fr::random_element(&engine); + barretenberg::fr d = barretenberg::fr::random_element(&engine); + + std::vector input1{ a, b, c, d }; + std::vector input2{ d, c, b, a }; + + auto r0 = crypto::Poseidon2::hash(input1); + auto r1 = crypto::Poseidon2::hash(input1); + auto r2 = crypto::Poseidon2::hash(input2); + + EXPECT_EQ(r0, r1); + EXPECT_NE(r0, r2); +} + +// N.B. these hardcoded values were extracted from the algorithm being tested. These are NOT independent test vectors! +// TODO(@zac-williamson #3132): find independent test vectors we can compare against! (very hard to find given +// flexibility of Poseidon's parametrisation) +TEST(Poseidon2, ConsistencyCheck) +{ + barretenberg::fr a(std::string("9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789")); + barretenberg::fr b(std::string("9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789")); + barretenberg::fr c(std::string("0x9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789")); + barretenberg::fr d(std::string("0x9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789")); + + std::array input{ a, b, c, d }; + auto result = crypto::Poseidon2::hash(input); + + barretenberg::fr expected(std::string("0x150c19ae11b3290c137c7a4d760d9482a6581d731535f560c3601d6a766b0937")); + + EXPECT_EQ(result, expected); +} +} // namespace poseidon2_tests \ No newline at end of file diff --git a/cpp/src/barretenberg/crypto/poseidon2/poseidon2_cpp_params.sage b/cpp/src/barretenberg/crypto/poseidon2/poseidon2_cpp_params.sage new file mode 100644 index 0000000000..98eb6ab020 --- /dev/null +++ b/cpp/src/barretenberg/crypto/poseidon2/poseidon2_cpp_params.sage @@ -0,0 +1,726 @@ +# Remark: Original sage script authored by Markus Schofnegger from Horizen Labs +# Original source: https://github.com/HorizenLabs/poseidon2/blob/main/poseidon2_rust_params.sage +# Licenced under MIT. +# Remark: This script contains functionality for GF(2^n), but currently works only over GF(p)! A few small adaptations are needed for GF(2^n). +from sage.rings.polynomial.polynomial_gf2x import GF2X_BuildIrred_list +from math import * +import itertools + +########################################################################### +# p = 18446744069414584321 # GoldiLocks +# p = 2013265921 # BabyBear +# p = 52435875175126190479447740508185965837690552500527637822603658699938581184513 # BLS12-381 +p = 21888242871839275222246405745257275088548364400416034343698204186575808495617 # BN254/BN256 +# p = 28948022309329048855892746252171976963363056481941560715954676764349967630337 # Pasta (Pallas) +# p = 28948022309329048855892746252171976963363056481941647379679742748393362948097 # Pasta (Vesta) + +n = len(p.bits()) # bit +# t = 12 # GoldiLocks (t = 12 for sponge, t = 8 for compression) +# t = 16 # BabyBear (t = 24 for sponge, t = 16 for compression) +t = 4 # BN254/BN256, BLS12-381, Pallas, Vesta (t = 4 for sponge, t = 3 for compression) + +FIELD = 1 +SBOX = 0 +FIELD_SIZE = n +NUM_CELLS = t + +def get_alpha(p): + for alpha in range(3, p): + if gcd(alpha, p-1) == 1: + break + return alpha + +alpha = get_alpha(p) + +def get_sbox_cost(R_F, R_P, N, t): + return int(t * R_F + R_P) + +def get_size_cost(R_F, R_P, N, t): + n = ceil(float(N) / t) + return int((N * R_F) + (n * R_P)) + +def poseidon_calc_final_numbers_fixed(p, t, alpha, M, security_margin): + # [Min. S-boxes] Find best possible for t and N + n = ceil(log(p, 2)) + N = int(n * t) + cost_function = get_sbox_cost + ret_list = [] + (R_F, R_P) = find_FD_round_numbers(p, t, alpha, M, cost_function, security_margin) + min_sbox_cost = cost_function(R_F, R_P, N, t) + ret_list.append(R_F) + ret_list.append(R_P) + ret_list.append(min_sbox_cost) + + # [Min. Size] Find best possible for t and N + # Minimum number of S-boxes for fixed n results in minimum size also (round numbers are the same)! + min_size_cost = get_size_cost(R_F, R_P, N, t) + ret_list.append(min_size_cost) + + return ret_list # [R_F, R_P, min_sbox_cost, min_size_cost] + +def find_FD_round_numbers(p, t, alpha, M, cost_function, security_margin): + n = ceil(log(p, 2)) + N = int(n * t) + + sat_inequiv = sat_inequiv_alpha + + R_P = 0 + R_F = 0 + min_cost = float("inf") + max_cost_rf = 0 + # Brute-force approach + for R_P_t in range(1, 500): + for R_F_t in range(4, 100): + if R_F_t % 2 == 0: + if (sat_inequiv(p, t, R_F_t, R_P_t, alpha, M) == True): + if security_margin == True: + R_F_t += 2 + R_P_t = int(ceil(float(R_P_t) * 1.075)) + cost = cost_function(R_F_t, R_P_t, N, t) + if (cost < min_cost) or ((cost == min_cost) and (R_F_t < max_cost_rf)): + R_P = ceil(R_P_t) + R_F = ceil(R_F_t) + min_cost = cost + max_cost_rf = R_F + return (int(R_F), int(R_P)) + +def sat_inequiv_alpha(p, t, R_F, R_P, alpha, M): + N = int(FIELD_SIZE * NUM_CELLS) + + if alpha > 0: + R_F_1 = 6 if M <= ((floor(log(p, 2) - ((alpha-1)/2.0))) * (t + 1)) else 10 # Statistical + R_F_2 = 1 + ceil(log(2, alpha) * min(M, FIELD_SIZE)) + ceil(log(t, alpha)) - R_P # Interpolation + R_F_3 = (log(2, alpha) * min(M, log(p, 2))) - R_P # Groebner 1 + R_F_4 = t - 1 + log(2, alpha) * min(M / float(t + 1), log(p, 2) / float(2)) - R_P # Groebner 2 + R_F_5 = (t - 2 + (M / float(2 * log(alpha, 2))) - R_P) / float(t - 1) # Groebner 3 + R_F_max = max(ceil(R_F_1), ceil(R_F_2), ceil(R_F_3), ceil(R_F_4), ceil(R_F_5)) + + # Addition due to https://eprint.iacr.org/2023/537.pdf + r_temp = floor(t / 3.0) + over = (R_F - 1) * t + R_P + r_temp + r_temp * (R_F / 2.0) + R_P + alpha + under = r_temp * (R_F / 2.0) + R_P + alpha + binom_log = log(binomial(over, under), 2) + if binom_log == inf: + binom_log = M + 1 + cost_gb4 = ceil(2 * binom_log) # Paper uses 2.3727, we are more conservative here + + return ((R_F >= R_F_max) and (cost_gb4 >= M)) + else: + print("Invalid value for alpha!") + exit(1) + +R_F_FIXED, R_P_FIXED, _, _ = poseidon_calc_final_numbers_fixed(p, t, alpha, 128, True) +#print("+++ R_F = {0}, R_P = {1} +++".format(R_F_FIXED, R_P_FIXED)) + +# For STARK TODO +# r_p_mod = R_P_FIXED % NUM_CELLS +# if r_p_mod != 0: +# R_P_FIXED = R_P_FIXED + NUM_CELLS - r_p_mod + +########################################################################### + +INIT_SEQUENCE = [] + +PRIME_NUMBER = p +# if FIELD == 1 and len(sys.argv) != 8: +# print("Please specify a prime number (in hex format)!") +# exit() +# elif FIELD == 1 and len(sys.argv) == 8: +# PRIME_NUMBER = int(sys.argv[7], 16) # e.g. 0xa7, 0xFFFFFFFFFFFFFEFF, 0xa1a42c3efd6dbfe08daa6041b36322ef + +F = GF(PRIME_NUMBER) + +def grain_sr_generator(): + bit_sequence = INIT_SEQUENCE + for _ in range(0, 160): + new_bit = bit_sequence[62] ^^ bit_sequence[51] ^^ bit_sequence[38] ^^ bit_sequence[23] ^^ bit_sequence[13] ^^ bit_sequence[0] + bit_sequence.pop(0) + bit_sequence.append(new_bit) + + while True: + new_bit = bit_sequence[62] ^^ bit_sequence[51] ^^ bit_sequence[38] ^^ bit_sequence[23] ^^ bit_sequence[13] ^^ bit_sequence[0] + bit_sequence.pop(0) + bit_sequence.append(new_bit) + while new_bit == 0: + new_bit = bit_sequence[62] ^^ bit_sequence[51] ^^ bit_sequence[38] ^^ bit_sequence[23] ^^ bit_sequence[13] ^^ bit_sequence[0] + bit_sequence.pop(0) + bit_sequence.append(new_bit) + new_bit = bit_sequence[62] ^^ bit_sequence[51] ^^ bit_sequence[38] ^^ bit_sequence[23] ^^ bit_sequence[13] ^^ bit_sequence[0] + bit_sequence.pop(0) + bit_sequence.append(new_bit) + new_bit = bit_sequence[62] ^^ bit_sequence[51] ^^ bit_sequence[38] ^^ bit_sequence[23] ^^ bit_sequence[13] ^^ bit_sequence[0] + bit_sequence.pop(0) + bit_sequence.append(new_bit) + yield new_bit +grain_gen = grain_sr_generator() + +def grain_random_bits(num_bits): + random_bits = [next(grain_gen) for i in range(0, num_bits)] + # random_bits.reverse() ## Remove comment to start from least significant bit + random_int = int("".join(str(i) for i in random_bits), 2) + return random_int + +def init_generator(field, sbox, n, t, R_F, R_P): + # Generate initial sequence based on parameters + bit_list_field = [_ for _ in (bin(FIELD)[2:].zfill(2))] + bit_list_sbox = [_ for _ in (bin(SBOX)[2:].zfill(4))] + bit_list_n = [_ for _ in (bin(FIELD_SIZE)[2:].zfill(12))] + bit_list_t = [_ for _ in (bin(NUM_CELLS)[2:].zfill(12))] + bit_list_R_F = [_ for _ in (bin(R_F)[2:].zfill(10))] + bit_list_R_P = [_ for _ in (bin(R_P)[2:].zfill(10))] + bit_list_1 = [1] * 30 + global INIT_SEQUENCE + INIT_SEQUENCE = bit_list_field + bit_list_sbox + bit_list_n + bit_list_t + bit_list_R_F + bit_list_R_P + bit_list_1 + INIT_SEQUENCE = [int(_) for _ in INIT_SEQUENCE] + +def generate_constants(field, n, t, R_F, R_P, prime_number): + round_constants = [] + # num_constants = (R_F + R_P) * t # Poseidon + num_constants = (R_F * t) + R_P # Poseidon2 + + if field == 0: + for i in range(0, num_constants): + random_int = grain_random_bits(n) + round_constants.append(random_int) + elif field == 1: + for i in range(0, num_constants): + random_int = grain_random_bits(n) + while random_int >= prime_number: + # print("[Info] Round constant is not in prime field! Taking next one.") + random_int = grain_random_bits(n) + round_constants.append(random_int) + # Add (t-1) zeroes for Poseidon2 if partial round + if i >= ((R_F/2) * t) and i < (((R_F/2) * t) + R_P): + round_constants.extend([0] * (t-1)) + return round_constants + +def print_round_constants(round_constants, n, field): + print("Number of round constants:", len(round_constants)) + + if field == 0: + print("Round constants for GF(2^n):") + elif field == 1: + print("Round constants for GF(p):") + hex_length = int(ceil(float(n) / 4)) + 2 # +2 for "0x" + print(["{0:#0{1}x}".format(entry, hex_length) for entry in round_constants]) + +def create_mds_p(n, t): + M = matrix(F, t, t) + + # Sample random distinct indices and assign to xs and ys + while True: + flag = True + rand_list = [F(grain_random_bits(n)) for _ in range(0, 2*t)] + while len(rand_list) != len(set(rand_list)): # Check for duplicates + rand_list = [F(grain_random_bits(n)) for _ in range(0, 2*t)] + xs = rand_list[:t] + ys = rand_list[t:] + # xs = [F(ele) for ele in range(0, t)] + # ys = [F(ele) for ele in range(t, 2*t)] + for i in range(0, t): + for j in range(0, t): + if (flag == False) or ((xs[i] + ys[j]) == 0): + flag = False + else: + entry = (xs[i] + ys[j])^(-1) + M[i, j] = entry + if flag == False: + continue + return M + +def generate_vectorspace(round_num, M, M_round, NUM_CELLS): + t = NUM_CELLS + s = 1 + V = VectorSpace(F, t) + if round_num == 0: + return V + elif round_num == 1: + return V.subspace(V.basis()[s:]) + else: + mat_temp = matrix(F) + for i in range(0, round_num-1): + add_rows = [] + for j in range(0, s): + add_rows.append(M_round[i].rows()[j][s:]) + mat_temp = matrix(mat_temp.rows() + add_rows) + r_k = mat_temp.right_kernel() + extended_basis_vectors = [] + for vec in r_k.basis(): + extended_basis_vectors.append(vector([0]*s + list(vec))) + S = V.subspace(extended_basis_vectors) + + return S + +def subspace_times_matrix(subspace, M, NUM_CELLS): + t = NUM_CELLS + V = VectorSpace(F, t) + subspace_basis = subspace.basis() + new_basis = [] + for vec in subspace_basis: + new_basis.append(M * vec) + new_subspace = V.subspace(new_basis) + return new_subspace + +# Returns True if the matrix is considered secure, False otherwise +def algorithm_1(M, NUM_CELLS): + t = NUM_CELLS + s = 1 + r = floor((t - s) / float(s)) + + # Generate round matrices + M_round = [] + for j in range(0, t+1): + M_round.append(M^(j+1)) + + for i in range(1, r+1): + mat_test = M^i + entry = mat_test[0, 0] + mat_target = matrix.circulant(vector([entry] + ([F(0)] * (t-1)))) + + if (mat_test - mat_target) == matrix.circulant(vector([F(0)] * (t))): + return [False, 1] + + S = generate_vectorspace(i, M, M_round, t) + V = VectorSpace(F, t) + + basis_vectors= [] + for eigenspace in mat_test.eigenspaces_right(format='galois'): + if (eigenspace[0] not in F): + continue + vector_subspace = eigenspace[1] + intersection = S.intersection(vector_subspace) + basis_vectors += intersection.basis() + IS = V.subspace(basis_vectors) + + if IS.dimension() >= 1 and IS != V: + return [False, 2] + for j in range(1, i+1): + S_mat_mul = subspace_times_matrix(S, M^j, t) + if S == S_mat_mul: + print("S.basis():\n", S.basis()) + return [False, 3] + + return [True, 0] + +# Returns True if the matrix is considered secure, False otherwise +def algorithm_2(M, NUM_CELLS): + t = NUM_CELLS + s = 1 + + V = VectorSpace(F, t) + trail = [None, None] + test_next = False + I = range(0, s) + I_powerset = list(sage.misc.misc.powerset(I))[1:] + for I_s in I_powerset: + test_next = False + new_basis = [] + for l in I_s: + new_basis.append(V.basis()[l]) + IS = V.subspace(new_basis) + for i in range(s, t): + new_basis.append(V.basis()[i]) + full_iota_space = V.subspace(new_basis) + for l in I_s: + v = V.basis()[l] + while True: + delta = IS.dimension() + v = M * v + IS = V.subspace(IS.basis() + [v]) + if IS.dimension() == t or IS.intersection(full_iota_space) != IS: + test_next = True + break + if IS.dimension() <= delta: + break + if test_next == True: + break + if test_next == True: + continue + return [False, [IS, I_s]] + + return [True, None] + +# Returns True if the matrix is considered secure, False otherwise +def algorithm_3(M, NUM_CELLS): + t = NUM_CELLS + s = 1 + + V = VectorSpace(F, t) + + l = 4*t + for r in range(2, l+1): + next_r = False + res_alg_2 = algorithm_2(M^r, t) + if res_alg_2[0] == False: + return [False, None] + + # if res_alg_2[1] == None: + # continue + # IS = res_alg_2[1][0] + # I_s = res_alg_2[1][1] + # for j in range(1, r): + # IS = subspace_times_matrix(IS, M, t) + # I_j = [] + # for i in range(0, s): + # new_basis = [] + # for k in range(0, t): + # if k != i: + # new_basis.append(V.basis()[k]) + # iota_space = V.subspace(new_basis) + # if IS.intersection(iota_space) != iota_space: + # single_iota_space = V.subspace([V.basis()[i]]) + # if IS.intersection(single_iota_space) == single_iota_space: + # I_j.append(i) + # else: + # next_r = True + # break + # if next_r == True: + # break + # if next_r == True: + # continue + # return [False, [IS, I_j, r]] + + return [True, None] + +def check_minpoly_condition(M, NUM_CELLS): + max_period = 2*NUM_CELLS + all_fulfilled = True + M_temp = M + for i in range(1, max_period + 1): + if not ((M_temp.minimal_polynomial().degree() == NUM_CELLS) and (M_temp.minimal_polynomial().is_irreducible() == True)): + all_fulfilled = False + break + M_temp = M * M_temp + return all_fulfilled + +def generate_matrix(FIELD, FIELD_SIZE, NUM_CELLS): + if FIELD == 0: + print("Matrix generation not implemented for GF(2^n).") + exit(1) + elif FIELD == 1: + mds_matrix = create_mds_p(FIELD_SIZE, NUM_CELLS) + result_1 = algorithm_1(mds_matrix, NUM_CELLS) + result_2 = algorithm_2(mds_matrix, NUM_CELLS) + result_3 = algorithm_3(mds_matrix, NUM_CELLS) + while result_1[0] == False or result_2[0] == False or result_3[0] == False: + mds_matrix = create_mds_p(FIELD_SIZE, NUM_CELLS) + result_1 = algorithm_1(mds_matrix, NUM_CELLS) + result_2 = algorithm_2(mds_matrix, NUM_CELLS) + result_3 = algorithm_3(mds_matrix, NUM_CELLS) + return mds_matrix + +def generate_matrix_full(NUM_CELLS): + M = None + if t == 2: + M = matrix.circulant(vector([F(2), F(1)])) + elif t == 3: + M = matrix.circulant(vector([F(2), F(1), F(1)])) + elif t == 4: + M = matrix(F, [[F(5), F(7), F(1), F(3)], [F(4), F(6), F(1), F(1)], [F(1), F(3), F(5), F(7)], [F(1), F(1), F(4), F(6)]]) + elif (t % 4) == 0: + M = matrix(F, t, t) + # M_small = matrix.circulant(vector([F(3), F(2), F(1), F(1)])) + M_small = matrix(F, [[F(5), F(7), F(1), F(3)], [F(4), F(6), F(1), F(1)], [F(1), F(3), F(5), F(7)], [F(1), F(1), F(4), F(6)]]) + small_num = t // 4 + for i in range(0, small_num): + for j in range(0, small_num): + if i == j: + M[i*4:(i+1)*4,j*4:(j+1)*4] = 2* M_small + else: + M[i*4:(i+1)*4,j*4:(j+1)*4] = M_small + else: + print("Error: No matrix for these parameters.") + exit() + return M + +def generate_matrix_partial(FIELD, FIELD_SIZE, NUM_CELLS): ## TODO: Prioritize small entries + entry_max_bit_size = FIELD_SIZE + if FIELD == 0: + print("Matrix generation not implemented for GF(2^n).") + exit(1) + elif FIELD == 1: + M = None + if t == 2: + M = matrix(F, [[F(2), F(1)], [F(1), F(3)]]) + elif t == 3: + M = matrix(F, [[F(2), F(1), F(1)], [F(1), F(2), F(1)], [F(1), F(1), F(3)]]) + else: + M_circulant = matrix.circulant(vector([F(0)] + [F(1) for _ in range(0, NUM_CELLS - 1)])) + M_diagonal = matrix.diagonal([F(grain_random_bits(entry_max_bit_size)) for _ in range(0, NUM_CELLS)]) + M = M_circulant + M_diagonal + # while algorithm_1(M, NUM_CELLS)[0] == False or algorithm_2(M, NUM_CELLS)[0] == False or algorithm_3(M, NUM_CELLS)[0] == False: + while check_minpoly_condition(M, NUM_CELLS) == False: + M_diagonal = matrix.diagonal([F(grain_random_bits(entry_max_bit_size)) for _ in range(0, NUM_CELLS)]) + M = M_circulant + M_diagonal + + if(algorithm_1(M, NUM_CELLS)[0] == False or algorithm_2(M, NUM_CELLS)[0] == False or algorithm_3(M, NUM_CELLS)[0] == False): + print("Error: Generated partial matrix is not secure w.r.t. subspace trails.") + exit() + return M + +def generate_matrix_partial_small_entries(FIELD, FIELD_SIZE, NUM_CELLS): + if FIELD == 0: + print("Matrix generation not implemented for GF(2^n).") + exit(1) + elif FIELD == 1: + M_circulant = matrix.circulant(vector([F(0)] + [F(1) for _ in range(0, NUM_CELLS - 1)])) + combinations = list(itertools.product(range(2, 6), repeat=NUM_CELLS)) + for entry in combinations: + M = M_circulant + matrix.diagonal(vector(F, list(entry))) + print(M) + # if M.is_invertible() == False or algorithm_1(M, NUM_CELLS)[0] == False or algorithm_2(M, NUM_CELLS)[0] == False or algorithm_3(M, NUM_CELLS)[0] == False: + if M.is_invertible() == False or check_minpoly_condition(M, NUM_CELLS) == False: + continue + return M + +def matrix_partial_m_1(matrix_partial, NUM_CELLS): + M_circulant = matrix.identity(F, NUM_CELLS) + return matrix_partial - M_circulant + +def print_linear_layer(M, n, t): + print("n:", n) + print("t:", t) + print("N:", (n * t)) + print("Result Algorithm 1:\n", algorithm_1(M, NUM_CELLS)) + print("Result Algorithm 2:\n", algorithm_2(M, NUM_CELLS)) + print("Result Algorithm 3:\n", algorithm_3(M, NUM_CELLS)) + hex_length = int(ceil(float(n) / 4)) + 2 # +2 for "0x" + print("Prime number:", "0x" + hex(PRIME_NUMBER)) + matrix_string = "[" + for i in range(0, t): + matrix_string += str(["{0:#0{1}x}".format(int(entry), hex_length) for entry in M[i]]) + if i < (t-1): + matrix_string += "," + matrix_string += "]" + print("MDS matrix:\n", matrix_string) + +def calc_equivalent_matrices(MDS_matrix_field): + # Following idea: Split M into M' * M'', where M'' is "cheap" and M' can move before the partial nonlinear layer + # The "previous" matrix layer is then M * M'. Due to the construction of M', the M[0,0] and v values will be the same for the new M' (and I also, obviously) + # Thus: Compute the matrices, store the w_hat and v_hat values + + MDS_matrix_field_transpose = MDS_matrix_field.transpose() + + w_hat_collection = [] + v_collection = [] + v = MDS_matrix_field_transpose[[0], list(range(1,t))] + + M_mul = MDS_matrix_field_transpose + M_i = matrix(F, t, t) + for i in range(R_P_FIXED - 1, -1, -1): + M_hat = M_mul[list(range(1,t)), list(range(1,t))] + w = M_mul[list(range(1,t)), [0]] + v = M_mul[[0], list(range(1,t))] + v_collection.append(v.list()) + w_hat = M_hat.inverse() * w + w_hat_collection.append(w_hat.list()) + + # Generate new M_i, and multiplication M * M_i for "previous" round + M_i = matrix.identity(t) + M_i[list(range(1,t)), list(range(1,t))] = M_hat + M_mul = MDS_matrix_field_transpose * M_i + + return M_i, v_collection, w_hat_collection, MDS_matrix_field_transpose[0, 0] + +def calc_equivalent_constants(constants, MDS_matrix_field): + constants_temp = [constants[index:index+t] for index in range(0, len(constants), t)] + + MDS_matrix_field_transpose = MDS_matrix_field.transpose() + + # Start moving round constants up + # Calculate c_i' = M^(-1) * c_(i+1) + # Split c_i': Add c_i'[0] AFTER the S-box, add the rest to c_i + # I.e.: Store c_i'[0] for each of the partial rounds, and make c_i = c_i + c_i' (where now c_i'[0] = 0) + num_rounds = R_F_FIXED + R_P_FIXED + R_f = R_F_FIXED / 2 + for i in range(num_rounds - 2 - R_f, R_f - 1, -1): + inv_cip1 = list(vector(constants_temp[i+1]) * MDS_matrix_field_transpose.inverse()) + constants_temp[i] = list(vector(constants_temp[i]) + vector([0] + inv_cip1[1:])) + constants_temp[i+1] = [inv_cip1[0]] + [0] * (t-1) + + return constants_temp + +def poseidon(input_words, matrix, round_constants): + + R_f = int(R_F_FIXED / 2) + + round_constants_counter = 0 + + state_words = list(input_words) + + # First full rounds + for r in range(0, R_f): + # Round constants, nonlinear layer, matrix multiplication + for i in range(0, t): + state_words[i] = state_words[i] + round_constants[round_constants_counter] + round_constants_counter += 1 + for i in range(0, t): + state_words[i] = (state_words[i])^alpha + state_words = list(matrix * vector(state_words)) + + # Middle partial rounds + for r in range(0, R_P_FIXED): + # Round constants, nonlinear layer, matrix multiplication + for i in range(0, t): + state_words[i] = state_words[i] + round_constants[round_constants_counter] + round_constants_counter += 1 + state_words[0] = (state_words[0])^alpha + state_words = list(matrix * vector(state_words)) + + # Last full rounds + for r in range(0, R_f): + # Round constants, nonlinear layer, matrix multiplication + for i in range(0, t): + state_words[i] = state_words[i] + round_constants[round_constants_counter] + round_constants_counter += 1 + for i in range(0, t): + state_words[i] = (state_words[i])^alpha + state_words = list(matrix * vector(state_words)) + + return state_words + +def poseidon2(input_words, matrix_full, matrix_partial, round_constants): + + R_f = int(R_F_FIXED / 2) + + round_constants_counter = 0 + + state_words = list(input_words) + + # First matrix mul + state_words = list(matrix_full * vector(state_words)) + + # First full rounds + for r in range(0, R_f): + # Round constants, nonlinear layer, matrix multiplication + for i in range(0, t): + state_words[i] = state_words[i] + round_constants[round_constants_counter] + round_constants_counter += 1 + for i in range(0, t): + state_words[i] = (state_words[i])^alpha + state_words = list(matrix_full * vector(state_words)) + + # Middle partial rounds + for r in range(0, R_P_FIXED): + # Round constants, nonlinear layer, matrix multiplication + for i in range(0, t): + state_words[i] = state_words[i] + round_constants[round_constants_counter] + round_constants_counter += 1 + state_words[0] = (state_words[0])^alpha + state_words = list(matrix_partial * vector(state_words)) + + # Last full rounds + for r in range(0, R_f): + # Round constants, nonlinear layer, matrix multiplication + for i in range(0, t): + state_words[i] = state_words[i] + round_constants[round_constants_counter] + round_constants_counter += 1 + for i in range(0, t): + state_words[i] = (state_words[i])^alpha + state_words = list(matrix_full * vector(state_words)) + + return state_words + +# Init +init_generator(FIELD, SBOX, FIELD_SIZE, NUM_CELLS, R_F_FIXED, R_P_FIXED) + +# Round constants +round_constants = generate_constants(FIELD, FIELD_SIZE, NUM_CELLS, R_F_FIXED, R_P_FIXED, PRIME_NUMBER) +# print_round_constants(round_constants, FIELD_SIZE, FIELD) + +# Matrix +# MDS = generate_matrix(FIELD, FIELD_SIZE, NUM_CELLS) +MATRIX_FULL = generate_matrix_full(NUM_CELLS) +MATRIX_PARTIAL = generate_matrix_partial(FIELD, FIELD_SIZE, NUM_CELLS) +MATRIX_PARTIAL_DIAGONAL_M_1 = [matrix_partial_m_1(MATRIX_PARTIAL, NUM_CELLS)[i,i] for i in range(0, NUM_CELLS)] + +def to_hex(value): + l = len(hex(p - 1)) + if l % 2 == 1: + l = l + 1 + value = hex(int(value))[2:] + value = "0x" + value.zfill(l - 2) + print("FF(std::string(\"{}\")),".format(value)) + + + +# # MDS +# print("pub static ref MDS{}: Vec> = vec![".format(t)) +# for vec in MDS: +# print("vec![", end="") +# for val in vec: +# to_hex(val) +# print("],") +# print("];") +# print() + +print("// poseidon2 paramters generated via sage script") +print("// original author: Markus Schofnegger from Horizen Labs") +print("// original source: https://github.com/HorizenLabs/poseidon2/blob/main/poseidon2_rust_params.sage") +print("#pragma once\n") + +print("#include \"barretenberg/ecc/curves/bn254/fr.hpp\"\n") + +print("namespace crypto {\n") + +print("struct Poseidon2Bn254ScalarFieldParams{\n") +print(" using FF = barretenberg::fr;") +print(" static constexpr size_t t = {};".format(t)) +print(" static constexpr size_t d = {};".format(alpha)) + +print(" static constexpr size_t rounds_f = {};".format(R_F_FIXED)) +print(" static constexpr size_t rounds_p = {};".format(R_P_FIXED)) +print(" static constexpr size_t sbox_size = {};".format(FIELD_SIZE)) + +# Efficient partial matrix (diagonal - 1) +print("static constexpr std::array internal_matrix_diagonal = {") +for val in MATRIX_PARTIAL_DIAGONAL_M_1: + to_hex(val) +print("};") +print() + +# Efficient partial matrix (full) +print("static constexpr std::array, t> internal_matrix = {") +for vec in MATRIX_PARTIAL: + print("std::array{") + for val in vec: + to_hex(val) + print("},") +print("};") +print() + +# Round constants +print("static constexpr std::array, rounds_f + rounds_p> round_constants{") +for (i,val) in enumerate(round_constants): + if i % t == 0: + print("std::array{") + to_hex(val) + if i % t == t - 1: + print("},") +print("};") +print() + +#print("pub static ref POSEIDON_{}_PARAMS: Arc> = Arc::new(PoseidonParams::new({}, {}, {}, {}, &MAT_DIAG{}_M_1, &RC{}));".format(t, t, alpha, R_F_FIXED, R_P_FIXED , t, t)) + + + +state_in = vector([F(i) for i in range(t)]) +# state_out = poseidon(state_in, MDS, round_constants) +state_out = poseidon2(state_in, MATRIX_FULL, MATRIX_PARTIAL, round_constants) + +for (i,val) in enumerate(state_in): + if i % t == 0: + print("static constexpr std::array TEST_VECTOR_INPUT{") + to_hex(val) + if i % t == t - 1: + print("};") + +for (i,val) in enumerate(state_out): + if i % t == 0: + print("static constexpr std::array TEST_VECTOR_OUTPUT{") + to_hex(val) + if i % t == t - 1: + print("};") + +print("};") +print("} // namespace crypto") \ No newline at end of file diff --git a/cpp/src/barretenberg/crypto/poseidon2/poseidon2_params.hpp b/cpp/src/barretenberg/crypto/poseidon2/poseidon2_params.hpp new file mode 100644 index 0000000000..430d75f1fb --- /dev/null +++ b/cpp/src/barretenberg/crypto/poseidon2/poseidon2_params.hpp @@ -0,0 +1,452 @@ +// poseidon2 paramters generated via sage script +// original author: Markus Schofnegger from Horizen Labs +// original source: https://github.com/HorizenLabs/poseidon2/blob/main/poseidon2_rust_params.sage +#pragma once + +#include "barretenberg/ecc/curves/bn254/fr.hpp" + +namespace crypto { + +struct Poseidon2Bn254ScalarFieldParams { + + using FF = barretenberg::fr; + static constexpr size_t t = 4; + static constexpr size_t d = 5; + static constexpr size_t rounds_f = 8; + static constexpr size_t rounds_p = 56; + static constexpr size_t sbox_size = 254; + static constexpr std::array internal_matrix_diagonal = { + FF(std::string("0x10dc6e9c006ea38b04b1e03b4bd9490c0d03f98929ca1d7fb56821fd19d3b6e7")), + FF(std::string("0x0c28145b6a44df3e0149b3d0a30b3bb599df9756d4dd9b84a86b38cfb45a740b")), + FF(std::string("0x00544b8338791518b2c7645a50392798b21f75bb60e3596170067d00141cac15")), + FF(std::string("0x222c01175718386f2e2e82eb122789e352e105a3b8fa852613bc534433ee428b")), + }; + + static constexpr std::array, t> internal_matrix = { + std::array{ + FF(std::string("0x10dc6e9c006ea38b04b1e03b4bd9490c0d03f98929ca1d7fb56821fd19d3b6e8")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000001")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000001")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000001")), + }, + std::array{ + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000001")), + FF(std::string("0x0c28145b6a44df3e0149b3d0a30b3bb599df9756d4dd9b84a86b38cfb45a740c")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000001")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000001")), + }, + std::array{ + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000001")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000001")), + FF(std::string("0x00544b8338791518b2c7645a50392798b21f75bb60e3596170067d00141cac16")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000001")), + }, + std::array{ + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000001")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000001")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000001")), + FF(std::string("0x222c01175718386f2e2e82eb122789e352e105a3b8fa852613bc534433ee428c")), + }, + }; + + static constexpr std::array, rounds_f + rounds_p> round_constants{ + std::array{ + FF(std::string("0x19b849f69450b06848da1d39bd5e4a4302bb86744edc26238b0878e269ed23e5")), + FF(std::string("0x265ddfe127dd51bd7239347b758f0a1320eb2cc7450acc1dad47f80c8dcf34d6")), + FF(std::string("0x199750ec472f1809e0f66a545e1e51624108ac845015c2aa3dfc36bab497d8aa")), + FF(std::string("0x157ff3fe65ac7208110f06a5f74302b14d743ea25067f0ffd032f787c7f1cdf8")), + }, + std::array{ + FF(std::string("0x2e49c43c4569dd9c5fd35ac45fca33f10b15c590692f8beefe18f4896ac94902")), + FF(std::string("0x0e35fb89981890520d4aef2b6d6506c3cb2f0b6973c24fa82731345ffa2d1f1e")), + FF(std::string("0x251ad47cb15c4f1105f109ae5e944f1ba9d9e7806d667ffec6fe723002e0b996")), + FF(std::string("0x13da07dc64d428369873e97160234641f8beb56fdd05e5f3563fa39d9c22df4e")), + }, + std::array{ + FF(std::string("0x0c009b84e650e6d23dc00c7dccef7483a553939689d350cd46e7b89055fd4738")), + FF(std::string("0x011f16b1c63a854f01992e3956f42d8b04eb650c6d535eb0203dec74befdca06")), + FF(std::string("0x0ed69e5e383a688f209d9a561daa79612f3f78d0467ad45485df07093f367549")), + FF(std::string("0x04dba94a7b0ce9e221acad41472b6bbe3aec507f5eb3d33f463672264c9f789b")), + }, + std::array{ + FF(std::string("0x0a3f2637d840f3a16eb094271c9d237b6036757d4bb50bf7ce732ff1d4fa28e8")), + FF(std::string("0x259a666f129eea198f8a1c502fdb38fa39b1f075569564b6e54a485d1182323f")), + FF(std::string("0x28bf7459c9b2f4c6d8e7d06a4ee3a47f7745d4271038e5157a32fdf7ede0d6a1")), + FF(std::string("0x0a1ca941f057037526ea200f489be8d4c37c85bbcce6a2aeec91bd6941432447")), + }, + std::array{ + FF(std::string("0x0c6f8f958be0e93053d7fd4fc54512855535ed1539f051dcb43a26fd926361cf")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x123106a93cd17578d426e8128ac9d90aa9e8a00708e296e084dd57e69caaf811")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x26e1ba52ad9285d97dd3ab52f8e840085e8fa83ff1e8f1877b074867cd2dee75")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x1cb55cad7bd133de18a64c5c47b9c97cbe4d8b7bf9e095864471537e6a4ae2c5")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x1dcd73e46acd8f8e0e2c7ce04bde7f6d2a53043d5060a41c7143f08e6e9055d0")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x011003e32f6d9c66f5852f05474a4def0cda294a0eb4e9b9b12b9bb4512e5574")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x2b1e809ac1d10ab29ad5f20d03a57dfebadfe5903f58bafed7c508dd2287ae8c")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x2539de1785b735999fb4dac35ee17ed0ef995d05ab2fc5faeaa69ae87bcec0a5")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x0c246c5a2ef8ee0126497f222b3e0a0ef4e1c3d41c86d46e43982cb11d77951d")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x192089c4974f68e95408148f7c0632edbb09e6a6ad1a1c2f3f0305f5d03b527b")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x1eae0ad8ab68b2f06a0ee36eeb0d0c058529097d91096b756d8fdc2fb5a60d85")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x179190e5d0e22179e46f8282872abc88db6e2fdc0dee99e69768bd98c5d06bfb")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x29bb9e2c9076732576e9a81c7ac4b83214528f7db00f31bf6cafe794a9b3cd1c")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x225d394e42207599403efd0c2464a90d52652645882aac35b10e590e6e691e08")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x064760623c25c8cf753d238055b444532be13557451c087de09efd454b23fd59")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x10ba3a0e01df92e87f301c4b716d8a394d67f4bf42a75c10922910a78f6b5b87")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x0e070bf53f8451b24f9c6e96b0c2a801cb511bc0c242eb9d361b77693f21471c")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x1b94cd61b051b04dd39755ff93821a73ccd6cb11d2491d8aa7f921014de252fb")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x1d7cb39bafb8c744e148787a2e70230f9d4e917d5713bb050487b5aa7d74070b")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x2ec93189bd1ab4f69117d0fe980c80ff8785c2961829f701bb74ac1f303b17db")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x2db366bfdd36d277a692bb825b86275beac404a19ae07a9082ea46bd83517926")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x062100eb485db06269655cf186a68532985275428450359adc99cec6960711b8")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x0761d33c66614aaa570e7f1e8244ca1120243f92fa59e4f900c567bf41f5a59b")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x20fc411a114d13992c2705aa034e3f315d78608a0f7de4ccf7a72e494855ad0d")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x25b5c004a4bdfcb5add9ec4e9ab219ba102c67e8b3effb5fc3a30f317250bc5a")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x23b1822d278ed632a494e58f6df6f5ed038b186d8474155ad87e7dff62b37f4b")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x22734b4c5c3f9493606c4ba9012499bf0f14d13bfcfcccaa16102a29cc2f69e0")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x26c0c8fe09eb30b7e27a74dc33492347e5bdff409aa3610254413d3fad795ce5")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x070dd0ccb6bd7bbae88eac03fa1fbb26196be3083a809829bbd626df348ccad9")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x12b6595bdb329b6fb043ba78bb28c3bec2c0a6de46d8c5ad6067c4ebfd4250da")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x248d97d7f76283d63bec30e7a5876c11c06fca9b275c671c5e33d95bb7e8d729")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x1a306d439d463b0816fc6fd64cc939318b45eb759ddde4aa106d15d9bd9baaaa")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x28a8f8372e3c38daced7c00421cb4621f4f1b54ddc27821b0d62d3d6ec7c56cf")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x0094975717f9a8a8bb35152f24d43294071ce320c829f388bc852183e1e2ce7e")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x04d5ee4c3aa78f7d80fde60d716480d3593f74d4f653ae83f4103246db2e8d65")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x2a6cf5e9aa03d4336349ad6fb8ed2269c7bef54b8822cc76d08495c12efde187")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x2304d31eaab960ba9274da43e19ddeb7f792180808fd6e43baae48d7efcba3f3")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x03fd9ac865a4b2a6d5e7009785817249bff08a7e0726fcb4e1c11d39d199f0b0")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x00b7258ded52bbda2248404d55ee5044798afc3a209193073f7954d4d63b0b64")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x159f81ada0771799ec38fca2d4bf65ebb13d3a74f3298db36272c5ca65e92d9a")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x1ef90e67437fbc8550237a75bc28e3bb9000130ea25f0c5471e144cf4264431f")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x1e65f838515e5ff0196b49aa41a2d2568df739bc176b08ec95a79ed82932e30d")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x2b1b045def3a166cec6ce768d079ba74b18c844e570e1f826575c1068c94c33f")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x0832e5753ceb0ff6402543b1109229c165dc2d73bef715e3f1c6e07c168bb173")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x02f614e9cedfb3dc6b762ae0a37d41bab1b841c2e8b6451bc5a8e3c390b6ad16")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x0e2427d38bd46a60dd640b8e362cad967370ebb777bedff40f6a0be27e7ed705")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x0493630b7c670b6deb7c84d414e7ce79049f0ec098c3c7c50768bbe29214a53a")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x22ead100e8e482674decdab17066c5a26bb1515355d5461a3dc06cc85327cea9")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x25b3e56e655b42cdaae2626ed2554d48583f1ae35626d04de5084e0b6d2a6f16")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x1e32752ada8836ef5837a6cde8ff13dbb599c336349e4c584b4fdc0a0cf6f9d0")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x2fa2a871c15a387cc50f68f6f3c3455b23c00995f05078f672a9864074d412e5")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x2f569b8a9a4424c9278e1db7311e889f54ccbf10661bab7fcd18e7c7a7d83505")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x044cb455110a8fdd531ade530234c518a7df93f7332ffd2144165374b246b43d")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x227808de93906d5d420246157f2e42b191fe8c90adfe118178ddc723a5319025")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x02fcca2934e046bc623adead873579865d03781ae090ad4a8579d2e7a6800355")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x0ef915f0ac120b876abccceb344a1d36bad3f3c5ab91a8ddcbec2e060d8befac")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + }, + std::array{ + FF(std::string("0x1797130f4b7a3e1777eb757bc6f287f6ab0fb85f6be63b09f3b16ef2b1405d38")), + FF(std::string("0x0a76225dc04170ae3306c85abab59e608c7f497c20156d4d36c668555decc6e5")), + FF(std::string("0x1fffb9ec1992d66ba1e77a7b93209af6f8fa76d48acb664796174b5326a31a5c")), + FF(std::string("0x25721c4fc15a3f2853b57c338fa538d85f8fbba6c6b9c6090611889b797b9c5f")), + }, + std::array{ + FF(std::string("0x0c817fd42d5f7a41215e3d07ba197216adb4c3790705da95eb63b982bfcaf75a")), + FF(std::string("0x13abe3f5239915d39f7e13c2c24970b6df8cf86ce00a22002bc15866e52b5a96")), + FF(std::string("0x2106feea546224ea12ef7f39987a46c85c1bc3dc29bdbd7a92cd60acb4d391ce")), + FF(std::string("0x21ca859468a746b6aaa79474a37dab49f1ca5a28c748bc7157e1b3345bb0f959")), + }, + std::array{ + FF(std::string("0x05ccd6255c1e6f0c5cf1f0df934194c62911d14d0321662a8f1a48999e34185b")), + FF(std::string("0x0f0e34a64b70a626e464d846674c4c8816c4fb267fe44fe6ea28678cb09490a4")), + FF(std::string("0x0558531a4e25470c6157794ca36d0e9647dbfcfe350d64838f5b1a8a2de0d4bf")), + FF(std::string("0x09d3dca9173ed2faceea125157683d18924cadad3f655a60b72f5864961f1455")), + }, + std::array{ + FF(std::string("0x0328cbd54e8c0913493f866ed03d218bf23f92d68aaec48617d4c722e5bd4335")), + FF(std::string("0x2bf07216e2aff0a223a487b1a7094e07e79e7bcc9798c648ee3347dd5329d34b")), + FF(std::string("0x1daf345a58006b736499c583cb76c316d6f78ed6a6dffc82111e11a63fe412df")), + FF(std::string("0x176563472456aaa746b694c60e1823611ef39039b2edc7ff391e6f2293d2c404")), + }, + }; + + static constexpr std::array TEST_VECTOR_INPUT{ + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000000")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000001")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000002")), + FF(std::string("0x0000000000000000000000000000000000000000000000000000000000000003")), + }; + static constexpr std::array TEST_VECTOR_OUTPUT{ + FF(std::string("0x01bd538c2ee014ed5141b29e9ae240bf8db3fe5b9a38629a9647cf8d76c01737")), + FF(std::string("0x239b62e7db98aa3a2a8f6a0d2fa1709e7a35959aa6c7034814d9daa90cbac662")), + FF(std::string("0x04cbb44c61d928ed06808456bf758cbf0c18d1e15a7b6dbc8245fa7515d5e3cb")), + FF(std::string("0x2e11c5cff2a22c64d01304b778d78f6998eff1ab73163a35603f54794c30847a")), + }; +}; +} // namespace crypto diff --git a/cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.hpp b/cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.hpp new file mode 100644 index 0000000000..40606b4887 --- /dev/null +++ b/cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.hpp @@ -0,0 +1,156 @@ +#pragma once + +#include "poseidon2_params.hpp" + +#include "barretenberg/common/throw_or_abort.hpp" + +#include +#include +#include +#include + +namespace crypto { + +/** + * @brief Applies the Poseidon2 permutation function from https://eprint.iacr.org/2023/323 . + * This algorithm was implemented using https://github.com/HorizenLabs/poseidon2 as a reference. + * + * @tparam Params + */ +template class Poseidon2Permutation { + public: + // t = sponge permutation size (in field elements) + // t = rate + capacity + // capacity = 1 field element (256 bits) + // rate = number of field elements that can be compressed per permutation + static constexpr size_t t = Params::t; + // d = degree of s-box polynomials. For a given field, `d` is the smallest element of `p` such that gdc(d, p - 1) = + // 1 (excluding 1) For bn254/grumpkin, d = 5 + static constexpr size_t d = Params::d; + // sbox size = number of bits in p + static constexpr size_t sbox_size = Params::sbox_size; + // number of full sbox rounds + static constexpr size_t rounds_f = Params::rounds_f; + // number of partial sbox rounds + static constexpr size_t rounds_p = Params::rounds_p; + static constexpr size_t NUM_ROUNDS = Params::rounds_f + Params::rounds_p; + + using FF = typename Params::FF; + using State = std::array; + using RoundConstants = std::array; + using MatrixDiagonal = std::array; + using RoundConstantsContainer = std::array; + + static constexpr MatrixDiagonal internal_matrix_diagonal = + Poseidon2Bn254ScalarFieldParams::internal_matrix_diagonal; + static constexpr RoundConstantsContainer round_constants = Poseidon2Bn254ScalarFieldParams::round_constants; + + static constexpr void matrix_multiplication_4x4(State& input) + { + /** + * hardcoded algorithm that evaluates matrix multiplication using the following MDS matrix: + * / \ + * | 5 7 1 3 | + * | 4 6 1 1 | + * | 1 3 5 7 | + * | 1 1 4 6 | + * \ / + * + * Algorithm is taken directly from the Poseidon2 paper. + */ + auto t0 = input[0] + input[1]; // A + B + auto t1 = input[2] + input[3]; // C + D + auto t2 = input[1] + input[1]; // 2B + t2 += t1; // 2B + C + D + auto t3 = input[3] + input[3]; // 2D + t3 += t0; // 2D + A + B + auto t4 = t1 + t1; + t4 += t4; + t4 += t3; // A + B + 4C + 6D + auto t5 = t0 + t0; + t5 += t5; + t5 += t2; // 4A + 6B + C + D + auto t6 = t3 + t5; // 5A + 7B + 3C + D + auto t7 = t2 + t4; // A + 3B + 5D + 7C + input[0] = t6; + input[1] = t5; + input[2] = t7; + input[3] = t4; + } + + static constexpr void add_round_constants(State& input, const RoundConstants& rc) + { + for (size_t i = 0; i < t; ++i) { + input[i] += rc[i]; + } + } + + static constexpr void matrix_multiplication_internal(State& input) + { + // for t = 4 + auto sum = input[0]; + for (size_t i = 1; i < t; ++i) { + sum += input[i]; + } + for (size_t i = 0; i < t; ++i) { + input[i] *= internal_matrix_diagonal[i]; + input[i] += sum; + } + } + + static constexpr void matrix_multiplication_external(State& input) + { + if constexpr (t == 4) { + matrix_multiplication_4x4(input); + } else { + // erm panic + throw_or_abort("not supported"); + } + } + + static constexpr void apply_single_sbox(FF& input) + { + // hardcoded assumption that d = 5. should fix this or not make d configurable + auto xx = input.sqr(); + auto xxxx = xx.sqr(); + input *= xxxx; + } + + static constexpr void apply_sbox(State& input) + { + for (auto& in : input) { + apply_single_sbox(in); + } + } + + static constexpr State permutation(const State& input) + { + // deep copy + State current_state(input); + + // Apply 1st linear layer + matrix_multiplication_external(current_state); + + constexpr size_t rounds_f_beginning = rounds_f / 2; + for (size_t i = 0; i < rounds_f_beginning; ++i) { + add_round_constants(current_state, round_constants[i]); + apply_sbox(current_state); + matrix_multiplication_external(current_state); + } + + const size_t p_end = rounds_f_beginning + rounds_p; + for (size_t i = rounds_f_beginning; i < p_end; ++i) { + current_state[0] += round_constants[i][0]; + apply_single_sbox(current_state[0]); + matrix_multiplication_internal(current_state); + } + + for (size_t i = p_end; i < NUM_ROUNDS; ++i) { + add_round_constants(current_state, round_constants[i]); + apply_sbox(current_state); + matrix_multiplication_external(current_state); + } + return current_state; + } +}; +} // namespace crypto \ No newline at end of file diff --git a/cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.test.cpp b/cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.test.cpp new file mode 100644 index 0000000000..759a00c768 --- /dev/null +++ b/cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.test.cpp @@ -0,0 +1,65 @@ +#include "poseidon2_permutation.hpp" +#include "barretenberg/crypto/poseidon2/poseidon2_params.hpp" +#include "barretenberg/ecc/curves/bn254/bn254.hpp" +#include + +using namespace barretenberg; + +namespace { +auto& engine = numeric::random::get_debug_engine(); +} + +namespace poseidon2_tests { + +TEST(Poseidon2Permutation, TestVectors) +{ + + auto input = crypto::Poseidon2Bn254ScalarFieldParams::TEST_VECTOR_INPUT; + auto expected = crypto::Poseidon2Bn254ScalarFieldParams::TEST_VECTOR_OUTPUT; + auto result = crypto::Poseidon2Permutation::permutation(input); + + EXPECT_EQ(result, expected); +} + +TEST(Poseidon2Permutation, BasicTests) +{ + + barretenberg::fr a = barretenberg::fr::random_element(&engine); + barretenberg::fr b = barretenberg::fr::random_element(&engine); + barretenberg::fr c = barretenberg::fr::random_element(&engine); + barretenberg::fr d = barretenberg::fr::random_element(&engine); + + std::array input1{ a, b, c, d }; + std::array input2{ d, c, b, a }; + + auto r0 = crypto::Poseidon2Permutation::permutation(input1); + auto r1 = crypto::Poseidon2Permutation::permutation(input1); + auto r2 = crypto::Poseidon2Permutation::permutation(input2); + + EXPECT_EQ(r0, r1); + EXPECT_NE(r0, r2); +} + +// N.B. these hardcoded values were extracted from the algorithm being tested. These are NOT independent test vectors! +// TODO(@zac-williamson #3132): find independent test vectors we can compare against! (very hard to find given +// flexibility of Poseidon's parametrisation) +TEST(Poseidon2Permutation, ConsistencyCheck) +{ + barretenberg::fr a(std::string("9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789")); + barretenberg::fr b(std::string("9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789")); + barretenberg::fr c(std::string("0x9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789")); + barretenberg::fr d(std::string("0x9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789")); + + std::array input{ a, b, c, d }; + auto result = crypto::Poseidon2Permutation::permutation(input); + + std::array expected{ + barretenberg::fr(std::string("0x2bf1eaf87f7d27e8dc4056e9af975985bccc89077a21891d6c7b6ccce0631f95")), + barretenberg::fr(std::string("0x0c01fa1b8d0748becafbe452c0cb0231c38224ea824554c9362518eebdd5701f")), + barretenberg::fr(std::string("0x018555a8eb50cf07f64b019ebaf3af3c925c93e631f3ecd455db07bbb52bbdd3")), + barretenberg::fr(std::string("0x0cbea457c91c22c6c31fd89afd2541efc2edf31736b9f721e823b2165c90fd41")), + }; + EXPECT_EQ(result, expected); +} + +} // namespace poseidon2_tests \ No newline at end of file diff --git a/cpp/src/barretenberg/crypto/poseidon2/sponge/sponge.hpp b/cpp/src/barretenberg/crypto/poseidon2/sponge/sponge.hpp new file mode 100644 index 0000000000..bc4a2c5c1d --- /dev/null +++ b/cpp/src/barretenberg/crypto/poseidon2/sponge/sponge.hpp @@ -0,0 +1,168 @@ +#pragma once + +#include +#include +#include +#include + +#include "barretenberg/numeric/uint256/uint256.hpp" + +namespace crypto { + +/** + * @brief Implements a cryptographic sponge over prime fields. + * Implements the sponge specification from the Community Cryptographic Specification Project + * see https://github.com/C2SP/C2SP/blob/792c1254124f625d459bfe34417e8f6bdd02eb28/poseidon-sponge.md + * (Note: this spec was not accepted into the C2SP repo, we might want to reference something else!) + * + * Note: If we ever use this sponge class for more than 1 hash functions, we should move this out of `poseidon2` + * and into its own directory + * @tparam FF + * @tparam rate + * @tparam capacity + * @tparam t + * @tparam Permutation + */ +template class FieldSponge { + public: + /** + * @brief Defines what phase of the sponge algorithm we are in. + * + * ABSORB: 'absorbing' field elements into the sponge + * SQUEEZE: compressing the sponge and extracting a field element + * + */ + enum Mode { + ABSORB, + SQUEEZE, + }; + + // sponge state. t = rate + capacity. capacity = 1 field element (~256 bits) + std::array state; + + // cached elements that have been absorbed. + std::array cache; + size_t cache_size = 0; + Mode mode = Mode::ABSORB; + + FieldSponge(FF domain_iv = 0) + { + for (size_t i = 0; i < rate; ++i) { + state[i] = 0; + } + state[rate] = domain_iv; + } + + std::array perform_duplex() + { + // zero-pad the cache + for (size_t i = cache_size; i < rate; ++i) { + cache[i] = 0; + } + // add the cache into sponge state + for (size_t i = 0; i < rate; ++i) { + state[i] += cache[i]; + } + state = Permutation::permutation(state); + // return `rate` number of field elements from the sponge state. + std::array output; + for (size_t i = 0; i < rate; ++i) { + output[i] = state[i]; + } + return output; + } + + void absorb(const FF& input) + { + if (mode == Mode::ABSORB && cache_size == rate) { + // If we're absorbing, and the cache is full, apply the sponge permutation to compress the cache + perform_duplex(); + cache[0] = input; + cache_size = 1; + } else if (mode == Mode::ABSORB && cache_size < rate) { + // If we're absorbing, and the cache is not full, add the input into the cache + cache[cache_size] = input; + cache_size += 1; + } else if (mode == Mode::SQUEEZE) { + // If we're in squeeze mode, switch to absorb mode and add the input into the cache. + // N.B. I don't think this code path can be reached?! + cache[0] = input; + cache_size = 1; + mode = Mode::ABSORB; + } + } + + FF squeeze() + { + if (mode == Mode::SQUEEZE && cache_size == 0) { + // If we're in squeze mode and the cache is empty, there is nothing left to squeeze out of the sponge! + // Switch to absorb mode. + mode = Mode::ABSORB; + cache_size = 0; + } + if (mode == Mode::ABSORB) { + // If we're in absorb mode, apply sponge permutation to compress the cache, populate cache with compressed + // state and switch to squeeze mode. Note: this code block will execute if the previous `if` condition was + // matched + auto new_output_elements = perform_duplex(); + mode = Mode::SQUEEZE; + for (size_t i = 0; i < rate; ++i) { + cache[i] = new_output_elements[i]; + } + cache_size = rate; + } + // By this point, we should have a non-empty cache. Pop one item off the top of the cache and return it. + FF result = cache[0]; + for (size_t i = 1; i < cache_size; ++i) { + cache[i - 1] = cache[i]; + } + cache_size -= 1; + cache[cache_size] = 0; + return result; + } + + /** + * @brief Use the sponge to hash an input string + * + * @tparam out_len + * @tparam is_variable_length. Distinguishes between hashes where the preimage length is constant/not constant + * @param input + * @return std::array + */ + template static std::array hash_internal(std::span input) + { + size_t in_len = input.size(); + const uint256_t iv = (static_cast(in_len) << 64) + out_len - 1; + FieldSponge sponge(iv); + + for (size_t i = 0; i < in_len; ++i) { + sponge.absorb(input[i]); + } + + // In the case where the hash preimage is variable-length, we append `1` to the end of the input, to distinguish + // from fixed-length hashes. (the combination of this additional field element + the hash IV ensures + // fixed-length and variable-length hashes do not collide) + if constexpr (is_variable_length) { + sponge.absorb(1); + } + + std::array output; + for (size_t i = 0; i < out_len; ++i) { + output[i] = sponge.squeeze(); + } + return output; + } + + template static std::array hash_fixed_length(std::span input) + { + return hash_internal(input); + } + static FF hash_fixed_length(std::span input) { return hash_fixed_length<1>(input)[0]; } + + template static std::array hash_variable_length(std::span input) + { + return hash_internal(input); + } + static FF hash_variable_length(std::span input) { return hash_variable_length<1>(input)[0]; } +}; +} // namespace crypto \ No newline at end of file diff --git a/cpp/src/barretenberg/dsl/acir_format/ecdsa_secp256r1.test.cpp b/cpp/src/barretenberg/dsl/acir_format/ecdsa_secp256r1.test.cpp index 01adec035f..0f8764c3bc 100644 --- a/cpp/src/barretenberg/dsl/acir_format/ecdsa_secp256r1.test.cpp +++ b/cpp/src/barretenberg/dsl/acir_format/ecdsa_secp256r1.test.cpp @@ -114,7 +114,7 @@ TEST(ECDSASecp256r1, test_hardcoded) }; crypto::ecdsa::key_pair account; - account.private_key = curve_ct::fr(uint256_t("020202020202020202020202020202020202020202020202020202020202020202")); + account.private_key = curve_ct::fr(uint256_t("0202020202020202020202020202020202020202020202020202020202020202")); account.public_key = curve_ct::g1::one * account.private_key; diff --git a/cpp/src/barretenberg/ecc/fields/field_declarations.hpp b/cpp/src/barretenberg/ecc/fields/field_declarations.hpp index c64b1e3519..a0446f1791 100644 --- a/cpp/src/barretenberg/ecc/fields/field_declarations.hpp +++ b/cpp/src/barretenberg/ecc/fields/field_declarations.hpp @@ -106,6 +106,12 @@ template struct alignas(32) field { self_to_montgomery_form(); } + constexpr explicit field(std::string input) noexcept + { + uint256_t value(input); + *this = field(value); + } + constexpr explicit operator uint32_t() const { field out = from_montgomery_form(); diff --git a/cpp/src/barretenberg/numeric/uint256/uint256.hpp b/cpp/src/barretenberg/numeric/uint256/uint256.hpp index 70c80c811a..d8cd9ef2f1 100644 --- a/cpp/src/barretenberg/numeric/uint256/uint256.hpp +++ b/cpp/src/barretenberg/numeric/uint256/uint256.hpp @@ -13,6 +13,7 @@ #include "../uint128/uint128.hpp" #include "barretenberg/common/serialize.hpp" +#include "barretenberg/common/throw_or_abort.hpp" #include #include #include @@ -35,13 +36,46 @@ class alignas(32) uint256_t { {} constexpr uint256_t(uint256_t&& other) noexcept = default; - explicit uint256_t(std::string const& str) noexcept + explicit constexpr uint256_t(std::string input) noexcept { - for (int i = 0; i < 4; ++i) { - std::stringstream ss; - ss << std::hex << str.substr(static_cast(i) * 16, 16); - ss >> data[3 - i]; + /* Quick and dirty conversion from a single character to its hex equivelent */ + constexpr auto HexCharToInt = [](uint8_t Input) { + bool valid = + (Input >= 'a' && Input <= 'f') || (Input >= 'A' && Input <= 'F') || (Input >= '0' && Input <= '9'); + if (!valid) { + throw_or_abort("Error, uint256 constructed from string_view with invalid hex parameter"); + } + uint8_t res = + ((Input >= 'a') && (Input <= 'f')) ? (Input - (static_cast('a') - static_cast(10))) + : ((Input >= 'A') && (Input <= 'F')) ? (Input - (static_cast('A') - static_cast(10))) + : ((Input >= '0') && (Input <= '9')) ? (Input - static_cast('0')) + : 0; + return res; + }; + + std::array limbs{ 0, 0, 0, 0 }; + size_t start_index = 0; + if (input.size() == 66 && input[0] == '0' && input[1] == 'x') { + start_index = 2; + } else if (input.size() != 64) { + throw_or_abort("Error, uint256 constructed from string_view with invalid length"); } + for (size_t j = 0; j < 4; ++j) { + + const size_t limb_index = start_index + j * 16; + for (size_t i = 0; i < 8; ++i) { + const size_t byte_index = limb_index + (i * 2); + uint8_t nibble_hi = HexCharToInt(static_cast(input[byte_index])); + uint8_t nibble_lo = HexCharToInt(static_cast(input[byte_index + 1])); + uint8_t byte = static_cast((nibble_hi * 16) + nibble_lo); + limbs[j] <<= 8; + limbs[j] += byte; + } + } + data[0] = limbs[3]; + data[1] = limbs[2]; + data[2] = limbs[1]; + data[3] = limbs[0]; } static constexpr uint256_t from_uint128(const uint128_t a) noexcept diff --git a/cpp/src/barretenberg/numeric/uint256/uint256.test.cpp b/cpp/src/barretenberg/numeric/uint256/uint256.test.cpp index 386119b386..28234d0c27 100644 --- a/cpp/src/barretenberg/numeric/uint256/uint256.test.cpp +++ b/cpp/src/barretenberg/numeric/uint256/uint256.test.cpp @@ -8,6 +8,27 @@ auto& engine = numeric::random::get_debug_engine(); using namespace numeric; +TEST(uint256, TestStringConstructors) +{ + std::string input = "9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789"; + const std::string input4("0x9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789"); + + const uint256_t result1(input); + constexpr uint256_t result2("9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789"); + const uint256_t result3("0x9a807b615c4d3e2fa0b1c2d3e4f56789fedcba9876543210abcdef0123456789"); + const uint256_t result4(input4); + constexpr uint256_t expected{ + 0xabcdef0123456789, + 0xfedcba9876543210, + 0xa0b1c2d3e4f56789, + 0x9a807b615c4d3e2f, + }; + EXPECT_EQ(result1, result2); + EXPECT_EQ(result1, result3); + EXPECT_EQ(result1, result4); + EXPECT_EQ(result1, expected); +} + TEST(uint256, GetBit) { constexpr uint256_t a{ 0b0110011001110010011001100111001001100110011100100110011001110011,