diff --git a/src/lib/pubkey/dilithium/dilithium/dilithium_modern.h b/src/lib/pubkey/dilithium/dilithium/dilithium_modern.h index 56e377a15e6..956550ac9d8 100644 --- a/src/lib/pubkey/dilithium/dilithium/dilithium_modern.h +++ b/src/lib/pubkey/dilithium/dilithium/dilithium_modern.h @@ -12,7 +12,6 @@ #include #include -#include #include #include @@ -23,26 +22,30 @@ namespace Botan { class Dilithium_Common_Symmetric_Primitives : public Dilithium_Symmetric_Primitives { public: - std::unique_ptr XOF(XofType type, std::span seed, uint16_t nonce) const override { - const auto xof_type = [&] { + Dilithium_Common_Symmetric_Primitives(size_t collision_strength_in_bytes) : + Dilithium_Symmetric_Primitives(collision_strength_in_bytes) {} + + Botan::XOF& XOF(XofType type, std::span seed, uint16_t nonce) const override { + auto& xof = [&]() -> Botan::XOF& { switch(type) { case XofType::k128: - return "SHAKE-128"; + return m_xof_128; case XofType::k256: - return "SHAKE-256"; + return m_xof_256; } BOTAN_ASSERT_UNREACHABLE(); }(); - std::array nonce_buffer; - store_le(nonce, nonce_buffer.data()); - - auto xof = Botan::XOF::create_or_throw(xof_type); - xof->update(seed); - xof->update(nonce_buffer); + xof.clear(); + xof.update(seed); + xof.update(store_le(nonce)); return xof; } + + private: + mutable SHAKE_256_XOF m_xof_256; + mutable SHAKE_128_XOF m_xof_128; }; } // namespace Botan diff --git a/src/lib/pubkey/dilithium/dilithium_aes/dilithium_aes.h b/src/lib/pubkey/dilithium/dilithium_aes/dilithium_aes.h index eff8cb4e278..882252311e4 100644 --- a/src/lib/pubkey/dilithium/dilithium_aes/dilithium_aes.h +++ b/src/lib/pubkey/dilithium/dilithium_aes/dilithium_aes.h @@ -22,8 +22,11 @@ namespace Botan { class Dilithium_AES_Symmetric_Primitives : public Dilithium_Symmetric_Primitives { public: + Dilithium_AES_Symmetric_Primitives(size_t collision_strength_in_bytes) : + Dilithium_Symmetric_Primitives(collision_strength_in_bytes) {} + // AES mode always uses AES-256, regardless of the XofType - std::unique_ptr XOF(XofType /* type */, std::span seed, uint16_t nonce) const final { + Botan::XOF& XOF(XofType /* type */, std::span seed, uint16_t nonce) const final { // Algorithm Spec V. 3.1 Section 5.3 // In the AES variant, the first 32 bytes of rhoprime are used as // the key and i is extended to a 12 byte nonce for AES-256 in @@ -36,10 +39,13 @@ class Dilithium_AES_Symmetric_Primitives : public Dilithium_Symmetric_Primitives const std::array iv{get_byte<1>(nonce), get_byte<0>(nonce), 0}; const auto key = seed.first(32); - auto xof = std::make_unique(); - xof->start(iv, key); - return xof; + m_aes_xof.clear(); + m_aes_xof.start(iv, key); + return m_aes_xof; } + + private: + mutable AES_256_CTR_XOF m_aes_xof; }; } // namespace Botan diff --git a/src/lib/pubkey/dilithium/dilithium_common/dilithium.cpp b/src/lib/pubkey/dilithium/dilithium_common/dilithium.cpp index 2c9387d4aa1..5b3acd2f56b 100644 --- a/src/lib/pubkey/dilithium/dilithium_common/dilithium.cpp +++ b/src/lib/pubkey/dilithium/dilithium_common/dilithium.cpp @@ -7,6 +7,7 @@ * (C) 2021-2023 Jack Lloyd * (C) 2021-2022 Manuel Glaser - Rohde & Schwarz Cybersecurity * (C) 2021-2023 Michael Boric, René Meusel - Rohde & Schwarz Cybersecurity +* (C) 2024 René Meusel - Rohde & Schwarz Cybersecurity * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -14,52 +15,18 @@ #include #include -#include #include -#include +#include +#include +#include #include -#include #include -#include #include -#include -#include -#include -#include -#include - namespace Botan { namespace { -std::pair calculate_t0_and_t1( - const DilithiumModeConstants& mode, - const std::vector& rho, - Dilithium::PolynomialVector s1, - const Dilithium::PolynomialVector& s2) { - /* Generate matrix */ - auto matrix = Dilithium::PolynomialMatrix::generate_matrix(rho, mode); - - /* Matrix-vector multiplication */ - s1.ntt(); - auto t = Dilithium::PolynomialVector::generate_polyvec_matrix_pointwise_montgomery(matrix.get_matrix(), s1, mode); - t.reduce(); - t.invntt_tomont(); - - /* Add error vector s2 */ - t.add_polyvec(s2); - - /* Extract t and write public key */ - t.cadd_q(); - - Dilithium::PolynomialVector t0(mode.k()); - Dilithium::PolynomialVector t1(mode.k()); - Dilithium::PolynomialVector::fill_polyvecs_power2round(t1, t0, t); - - return {std::move(t0), std::move(t1)}; -} - DilithiumMode::Mode dilithium_mode_from_string(std::string_view str) { if(str == "Dilithium-4x4-r3") { return DilithiumMode::Dilithium4x4; @@ -114,402 +81,272 @@ std::string DilithiumMode::to_string() const { class Dilithium_PublicKeyInternal { public: - Dilithium_PublicKeyInternal(DilithiumModeConstants mode) : m_mode(std::move(mode)) {} - - Dilithium_PublicKeyInternal(DilithiumModeConstants mode, std::span raw_pk) : - m_mode(std::move(mode)) { - BOTAN_ASSERT_NOMSG(raw_pk.size() == m_mode.public_key_bytes()); - - BufferSlicer s(raw_pk); - m_rho = s.copy_as_vector(DilithiumModeConstants::SEEDBYTES); - m_t1 = Dilithium::PolynomialVector::unpack_t1(s.take(DilithiumModeConstants::POLYT1_PACKEDBYTES * m_mode.k()), - m_mode); - - BOTAN_ASSERT_NOMSG(s.remaining() == 0); - BOTAN_STATE_CHECK(m_t1.m_vec.size() == m_mode.k()); - - m_raw_pk_shake256 = compute_raw_pk_shake256(); + static std::shared_ptr decode( + DilithiumConstants mode, StrongSpan raw_pk) { + auto [rho, t1] = dilithium_decode_public_key(raw_pk, mode); + return std::make_shared(std::move(mode), std::move(rho), std::move(t1)); } - Dilithium_PublicKeyInternal(DilithiumModeConstants mode, - std::vector rho, - const Dilithium::PolynomialVector& s1, - const Dilithium::PolynomialVector& s2) : + Dilithium_PublicKeyInternal(DilithiumConstants mode, DilithiumSeedRho rho, DilithiumPolyVec t1) : m_mode(std::move(mode)), m_rho(std::move(rho)), - m_t1([&] { return calculate_t0_and_t1(m_mode, m_rho, s1, s2).second; }()) { - BOTAN_ASSERT_NOMSG(!m_rho.empty()); - BOTAN_ASSERT_NOMSG(!m_t1.m_vec.empty()); - m_raw_pk_shake256 = compute_raw_pk_shake256(); - } - - Dilithium_PublicKeyInternal(DilithiumModeConstants mode, - std::vector rho, - Dilithium::PolynomialVector t1) : - m_mode(std::move(mode)), m_rho(std::move(rho)), m_t1(std::move(t1)) { + m_t1(std::move(t1)), + m_tr(m_mode.symmetric_primitives().H(raw_pk())) { BOTAN_ASSERT_NOMSG(!m_rho.empty()); - BOTAN_ASSERT_NOMSG(!m_t1.m_vec.empty()); - m_raw_pk_shake256 = compute_raw_pk_shake256(); + BOTAN_ASSERT_NOMSG(m_t1.size() > 0); } - ~Dilithium_PublicKeyInternal() = default; - - Dilithium_PublicKeyInternal(const Dilithium_PublicKeyInternal&) = delete; - Dilithium_PublicKeyInternal(Dilithium_PublicKeyInternal&&) = delete; - Dilithium_PublicKeyInternal& operator=(const Dilithium_PublicKeyInternal& other) = delete; - Dilithium_PublicKeyInternal& operator=(Dilithium_PublicKeyInternal&& other) = delete; + public: + DilithiumSerializedPublicKey raw_pk() const { return dilithium_encode_public_key(m_rho, m_t1, m_mode); } - std::vector raw_pk() const { return concat>(m_rho, m_t1.polyvec_pack_t1()); } + const DilithiumHashedPublicKey& tr() const { return m_tr; } - const std::vector& raw_pk_shake256() const { - BOTAN_STATE_CHECK(m_raw_pk_shake256.size() == DilithiumModeConstants::SEEDBYTES); - return m_raw_pk_shake256; - } + const DilithiumPolyVec& t1() const { return m_t1; } - const Dilithium::PolynomialVector& t1() const { return m_t1; } + const DilithiumSeedRho& rho() const { return m_rho; } - const std::vector& rho() const { return m_rho; } - - const DilithiumModeConstants& mode() const { return m_mode; } + const DilithiumConstants& mode() const { return m_mode; } private: - std::vector compute_raw_pk_shake256() const { - SHAKE_256 shake(DilithiumModeConstants::SEEDBYTES * 8); - shake.update(m_rho); - shake.update(m_t1.polyvec_pack_t1()); - return shake.final_stdvec(); - } - - const DilithiumModeConstants m_mode; - std::vector m_raw_pk_shake256; - std::vector m_rho; - Dilithium::PolynomialVector m_t1; + const DilithiumConstants m_mode; + DilithiumSeedRho m_rho; + DilithiumPolyVec m_t1; + DilithiumHashedPublicKey m_tr; }; class Dilithium_PrivateKeyInternal { public: - Dilithium_PrivateKeyInternal(DilithiumModeConstants mode) : m_mode(std::move(mode)) {} - - Dilithium_PrivateKeyInternal(DilithiumModeConstants mode, - std::vector rho, - secure_vector tr, - secure_vector key, - Dilithium::PolynomialVector s1, - Dilithium::PolynomialVector s2, - Dilithium::PolynomialVector t0) : + static std::shared_ptr decode(DilithiumConstants mode, + StrongSpan sk) { + auto [rho, signing_seed, tr, s1, s2, t0] = dilithium_decode_private_key(sk, mode); + return std::make_shared(std::move(mode), + std::move(rho), + std::move(signing_seed), + std::move(tr), + std::move(s1), + std::move(s2), + std::move(t0)); + } + + Dilithium_PrivateKeyInternal(DilithiumConstants mode, + DilithiumSeedRho rho, + DilithiumSigningSeedK signing_seed, + DilithiumHashedPublicKey tr, + DilithiumPolyVec s1, + DilithiumPolyVec s2, + DilithiumPolyVec t0) : m_mode(std::move(mode)), m_rho(std::move(rho)), + m_signing_seed(std::move(signing_seed)), m_tr(std::move(tr)), - m_key(std::move(key)), m_t0(std::move(t0)), m_s1(std::move(s1)), m_s2(std::move(s2)) {} - Dilithium_PrivateKeyInternal(DilithiumModeConstants mode, std::span sk) : - Dilithium_PrivateKeyInternal(std::move(mode)) { - BOTAN_ASSERT_NOMSG(sk.size() == m_mode.private_key_bytes()); - - BufferSlicer s(sk); - m_rho = s.copy_as_vector(DilithiumModeConstants::SEEDBYTES); - m_key = s.copy_as_secure_vector(DilithiumModeConstants::SEEDBYTES); - m_tr = s.copy_as_secure_vector(DilithiumModeConstants::SEEDBYTES); - m_s1 = Dilithium::PolynomialVector::unpack_eta( - s.take(m_mode.l() * m_mode.polyeta_packedbytes()), m_mode.l(), m_mode); - m_s2 = Dilithium::PolynomialVector::unpack_eta( - s.take(m_mode.k() * m_mode.polyeta_packedbytes()), m_mode.k(), m_mode); - m_t0 = Dilithium::PolynomialVector::unpack_t0(s.take(m_mode.k() * DilithiumModeConstants::POLYT0_PACKEDBYTES), - m_mode); - } - - secure_vector raw_sk() const { - return concat>( - m_rho, m_key, m_tr, m_s1.polyvec_pack_eta(m_mode), m_s2.polyvec_pack_eta(m_mode), m_t0.polyvec_pack_t0()); + public: + DilithiumSerializedPrivateKey raw_sk() const { + return dilithium_encode_private_key(m_rho, m_tr, m_signing_seed, m_s1, m_s2, m_t0, m_mode); } - const DilithiumModeConstants& mode() const { return m_mode; } + const DilithiumConstants& mode() const { return m_mode; } - const std::vector& rho() const { return m_rho; } + const DilithiumSeedRho& rho() const { return m_rho; } - const secure_vector& get_key() const { return m_key; } + const DilithiumSigningSeedK& signing_seed() const { return m_signing_seed; } - const secure_vector& tr() const { return m_tr; } + const DilithiumHashedPublicKey& tr() const { return m_tr; } - const Dilithium::PolynomialVector& s1() const { return m_s1; } + const DilithiumPolyVec& s1() const { return m_s1; } - const Dilithium::PolynomialVector& s2() const { return m_s2; } + const DilithiumPolyVec& s2() const { return m_s2; } - const Dilithium::PolynomialVector& t0() const { return m_t0; } + const DilithiumPolyVec& t0() const { return m_t0; } private: - const DilithiumModeConstants m_mode; - std::vector m_rho; - secure_vector m_tr, m_key; - Dilithium::PolynomialVector m_t0, m_s1, m_s2; + const DilithiumConstants m_mode; + DilithiumSeedRho m_rho; + DilithiumSigningSeedK m_signing_seed; + DilithiumHashedPublicKey m_tr; + DilithiumPolyVec m_t0; + DilithiumPolyVec m_s1; + DilithiumPolyVec m_s2; }; class Dilithium_Signature_Operation final : public PK_Ops::Signature { public: - Dilithium_Signature_Operation(const Dilithium_PrivateKey& priv_key_dilithium, bool randomized) : - m_priv_key(priv_key_dilithium), - m_matrix( - Dilithium::PolynomialMatrix::generate_matrix(m_priv_key.m_private->rho(), m_priv_key.m_private->mode())), - m_shake(DilithiumModeConstants::CRHBYTES * 8), - m_randomized(randomized) { - m_shake.update(m_priv_key.m_private->tr()); - } - - void update(const uint8_t msg[], size_t msg_len) override { m_shake.update(msg, msg_len); } - + Dilithium_Signature_Operation(std::shared_ptr sk, bool randomized) : + m_priv_key(std::move(sk)), + m_mode(m_priv_key->mode()), + m_randomized(randomized), + m_h(m_mode.symmetric_primitives().get_message_hash(m_priv_key->tr())), + m_s1(ntt(m_priv_key->s1().clone())), + m_s2(ntt(m_priv_key->s2().clone())), + m_t0(ntt(m_priv_key->t0().clone())), + m_A(dilithium_expand_A(m_priv_key->rho(), m_mode)) {} + + void update(const uint8_t msg[], size_t msg_len) override { m_h.update({msg, msg_len}); } + + /** + * NIST FIPS 204 IPD, Algorithm 2 (ML-DSA.Sign) + * + * Note that the private key decoding is done ahead of time. Also, the + * matrix expansion of A from 'rho' along with the NTT-transforms of s1, + * s2 and t0 are done in the constructor of this class, as a 'signature + * operation' may be used to sign multiple messages. + */ secure_vector sign(RandomNumberGenerator& rng) override { - const auto mu = m_shake.final_stdvec(); - - // Get set up for the next message (if any) - m_shake.update(m_priv_key.m_private->tr()); + const auto mu = m_h.final(); + const auto& sympri = m_mode.symmetric_primitives(); - const auto& mode = m_priv_key.m_private->mode(); - - const auto rhoprime = (m_randomized) ? rng.random_vec(DilithiumModeConstants::CRHBYTES) - : mode.CRH(concat(m_priv_key.m_private->get_key(), mu)); - - /* Transform vectors */ - auto s1 = m_priv_key.m_private->s1(); - s1.ntt(); - - auto s2 = m_priv_key.m_private->s2(); - s2.ntt(); - - auto t0 = m_priv_key.m_private->t0(); - t0.ntt(); + // TODO: ML-DSA generates rhoprime differently, namely + // rhoprime = H(K, rnd, mu) with rnd being 32 random bytes or 32 zero bytes + const auto rhoprime = (m_randomized) + ? rng.random_vec(DilithiumConstants::SEED_RHOPRIME_BYTES) + : sympri.H(m_priv_key->signing_seed(), mu); // Note: nonce (as requested by `polyvecl_uniform_gamma1`) is actually just uint16_t // but to avoid an integer overflow, we use uint32_t as the loop variable. - for(uint32_t nonce = 0; nonce <= std::numeric_limits::max(); ++nonce) { - /* Sample intermediate vector y */ - Dilithium::PolynomialVector y(mode.l()); - - y.polyvecl_uniform_gamma1(rhoprime, static_cast(nonce), mode); - - auto z = y; - z.ntt(); - - /* Matrix-vector multiplication */ - auto w1 = Dilithium::PolynomialVector::generate_polyvec_matrix_pointwise_montgomery( - m_matrix.get_matrix(), z, mode); - - w1.reduce(); - w1.invntt_tomont(); - - /* Decompose w and call the random oracle */ - w1.cadd_q(); - - auto w1_w0 = w1.polyvec_decompose(mode); - - auto packed_w1 = std::get<0>(w1_w0).polyvec_pack_w1(mode); - - SHAKE_256 shake256_variable(DilithiumModeConstants::SEEDBYTES * 8); - shake256_variable.update(mu.data(), DilithiumModeConstants::CRHBYTES); - shake256_variable.update(packed_w1.data(), packed_w1.size()); - auto sm = shake256_variable.final(); - - auto cp = Dilithium::Polynomial::poly_challenge(sm.data(), mode); - cp.ntt(); - - /* Compute z, reject if it reveals secret */ - s1.polyvec_pointwise_poly_montgomery(z, cp); - - z.invntt_tomont(); - z.add_polyvec(y); - + for(uint32_t nonce = 0; nonce <= std::numeric_limits::max(); nonce += m_mode.l()) { + const auto y = dilithium_expand_mask(rhoprime, static_cast(nonce), m_mode); + + auto w_ntt = m_A * ntt(y.clone()); + w_ntt.reduce(); + auto w = inverse_ntt(std::move(w_ntt)); + w.conditional_add_q(); + + auto [w1, w0] = dilithium_decompose(w, m_mode); + const auto ch = sympri.H(mu, dilithium_encode_commitment(w1, m_mode)); + StrongSpan c1( + std::span(ch).first(DilithiumConstants::COMMITMENT_HASH_C1_BYTES)); + const auto c = ntt(dilithium_sample_in_ball(c1, m_mode)); + + auto cs1 = inverse_ntt(c * m_s1); + auto z = y + cs1; z.reduce(); - if(z.polyvec_chknorm(mode.gamma1() - mode.beta())) { + if(!dilithium_infinity_norm_within_bound(z, to_underlying(m_mode.gamma1()) - m_mode.beta())) { continue; } - /* Check that subtracting cs2 does not change high bits of w and low bits - * do not reveal secret information */ - Dilithium::PolynomialVector h(mode.k()); - s2.polyvec_pointwise_poly_montgomery(h, cp); - h.invntt_tomont(); - std::get<1>(w1_w0) -= h; - std::get<1>(w1_w0).reduce(); - - if(std::get<1>(w1_w0).polyvec_chknorm(mode.gamma2() - mode.beta())) { + auto cs2 = inverse_ntt(c * m_s2); + w0 -= cs2; + w0.reduce(); + if(!dilithium_infinity_norm_within_bound(w0, to_underlying(m_mode.gamma2()) - m_mode.beta())) { continue; } - /* Compute hints for w1 */ - t0.polyvec_pointwise_poly_montgomery(h, cp); - h.invntt_tomont(); - h.reduce(); - if(h.polyvec_chknorm(mode.gamma2())) { + auto ct0 = inverse_ntt(c * m_t0); + ct0.reduce(); + if(!dilithium_infinity_norm_within_bound(ct0, m_mode.gamma2())) { continue; } - std::get<1>(w1_w0).add_polyvec(h); - std::get<1>(w1_w0).cadd_q(); + w0 += ct0; + w0.conditional_add_q(); - auto n = - Dilithium::PolynomialVector::generate_hint_polyvec(h, std::get<1>(w1_w0), std::get<0>(w1_w0), mode); - if(n > mode.omega()) { + auto hint = dilithium_make_hint(w0, w1, m_mode); + if(hint.hamming_weight() > m_mode.omega()) { continue; } - /* Write signature */ - return pack_sig(sm, z, h); + return dilithium_encode_signature(ch, z, hint, m_mode).get(); } throw Internal_Error("Dilithium signature loop did not terminate"); } - size_t signature_length() const override { - const auto& dilithium_math = m_priv_key.m_private->mode(); - return dilithium_math.crypto_bytes(); - } + size_t signature_length() const override { return m_priv_key->mode().signature_bytes(); } - AlgorithmIdentifier algorithm_identifier() const override; + AlgorithmIdentifier algorithm_identifier() const override { + return AlgorithmIdentifier(m_priv_key->mode().mode().object_identifier(), + AlgorithmIdentifier::USE_EMPTY_PARAM); + } - std::string hash_function() const override { return "SHAKE-256(512)"; } + std::string hash_function() const override { return m_h.name(); } private: - // Bit-pack signature sig = (c, z, h). - secure_vector pack_sig(const secure_vector& c, - const Dilithium::PolynomialVector& z, - const Dilithium::PolynomialVector& h) { - BOTAN_ASSERT_NOMSG(c.size() == DilithiumModeConstants::SEEDBYTES); - size_t position = 0; - const auto& mode = m_priv_key.m_private->mode(); - secure_vector sig(mode.crypto_bytes()); - - std::copy(c.begin(), c.end(), sig.begin()); - position += DilithiumModeConstants::SEEDBYTES; - - for(size_t i = 0; i < mode.l(); ++i) { - z.m_vec[i].polyz_pack(&sig[position + i * mode.polyz_packedbytes()], mode); - } - position += mode.l() * mode.polyz_packedbytes(); - - /* Encode h */ - for(size_t i = 0; i < mode.omega() + mode.k(); ++i) { - sig[i + position] = 0; - } - - size_t k = 0; - for(size_t i = 0; i < mode.k(); ++i) { - for(size_t j = 0; j < DilithiumModeConstants::N; ++j) { - if(h.m_vec[i].m_coeffs[j] != 0) { - sig[position + k] = static_cast(j); - k++; - } - } - sig[position + mode.omega() + i] = static_cast(k); - } - return sig; - } - - const Dilithium_PrivateKey m_priv_key; - const Dilithium::PolynomialMatrix m_matrix; - SHAKE_256 m_shake; + std::shared_ptr m_priv_key; + const DilithiumConstants& m_mode; bool m_randomized; -}; + DilithiumMessageHash m_h; -AlgorithmIdentifier Dilithium_Signature_Operation::algorithm_identifier() const { - return m_priv_key.algorithm_identifier(); -} + const DilithiumPolyVecNTT m_s1; + const DilithiumPolyVecNTT m_s2; + const DilithiumPolyVecNTT m_t0; + const DilithiumPolyMatNTT m_A; +}; class Dilithium_Verification_Operation final : public PK_Ops::Verification { public: - Dilithium_Verification_Operation(const Dilithium_PublicKey& pub_dilithium) : - m_pub_key(pub_dilithium.m_public), - m_matrix(Dilithium::PolynomialMatrix::generate_matrix(m_pub_key->rho(), m_pub_key->mode())), - m_pk_hash(m_pub_key->raw_pk_shake256()), - m_shake(DilithiumModeConstants::CRHBYTES * 8) { - m_shake.update(m_pk_hash); - } - - /* - * Add more data to the message currently being signed - * @param msg the message - * @param msg_len the length of msg in bytes - */ - void update(const uint8_t msg[], size_t msg_len) override { m_shake.update(msg, msg_len); } - - /* - * Perform a verification operation - * @param rng a random number generator - */ + Dilithium_Verification_Operation(std::shared_ptr pubkey) : + m_pub_key(std::move(pubkey)), + m_mode(m_pub_key->mode()), + m_A(dilithium_expand_A(m_pub_key->rho(), m_mode)), + m_h(m_mode.symmetric_primitives().get_message_hash(m_pub_key->tr())) {} + + void update(const uint8_t msg[], size_t msg_len) override { m_h.update({msg, msg_len}); } + + /** + * NIST FIPS 204 IPD, Algorithm 3 (ML-DSA.Verify) + * + * Note that the public key decoding is done ahead of time. Also, the + * matrix A is expanded from 'rho' in the constructor of this class, as + * a 'verification operation' may be used to verify multiple signatures. + */ bool is_valid_signature(const uint8_t* sig, size_t sig_len) override { - /* Compute CRH(H(rho, t1), msg) */ - const auto mu = m_shake.final_stdvec(); - - // Reset the SHAKE context for the next message - m_shake.update(m_pk_hash); - - const auto& mode = m_pub_key->mode(); + const auto& sympri = m_mode.symmetric_primitives(); + StrongSpan sig_bytes({sig, sig_len}); - if(sig_len != mode.crypto_bytes()) { + if(sig_bytes.size() != m_mode.signature_bytes()) { return false; } - Dilithium::PolynomialVector z(mode.l()); - Dilithium::PolynomialVector h(mode.k()); - std::vector signature(sig, sig + sig_len); - std::array c; - if(Dilithium::PolynomialVector::unpack_sig(c, z, h, signature, mode)) { + const auto mu = m_h.final(); + + auto signature = dilithium_decode_signature(sig_bytes, m_mode); + if(!signature.has_value()) { return false; } + auto [ch, z, h] = std::move(signature.value()); + StrongSpan c1( + std::span(ch).first(DilithiumConstants::COMMITMENT_HASH_C1_BYTES)); - if(z.polyvec_chknorm(mode.gamma1() - mode.beta())) { + if(h.hamming_weight() > m_mode.omega() || + !dilithium_infinity_norm_within_bound(z, to_underlying(m_mode.gamma1()) - m_mode.beta())) { return false; } - /* Matrix-vector multiplication; compute Az - c2^dt1 */ - auto cp = Dilithium::Polynomial::poly_challenge(c.data(), mode); - cp.ntt(); - - Dilithium::PolynomialVector t1 = m_pub_key->t1(); - t1.polyvec_shiftl(); - t1.ntt(); - t1.polyvec_pointwise_poly_montgomery(t1, cp); - - z.ntt(); - - auto w1 = - Dilithium::PolynomialVector::generate_polyvec_matrix_pointwise_montgomery(m_matrix.get_matrix(), z, mode); - w1 -= t1; - w1.reduce(); - w1.invntt_tomont(); - w1.cadd_q(); - w1.polyvec_use_hint(w1, h, mode); - auto packed_w1 = w1.polyvec_pack_w1(mode); - - /* Call random oracle and verify challenge */ - SHAKE_256 shake256_variable(DilithiumModeConstants::SEEDBYTES * 8); - shake256_variable.update(mu.data(), mu.size()); - shake256_variable.update(packed_w1.data(), packed_w1.size()); - auto c2 = shake256_variable.final(); - - BOTAN_ASSERT_NOMSG(c.size() == c2.size()); - return std::equal(c.begin(), c.end(), c2.begin()); + const auto c_hat = ntt(dilithium_sample_in_ball(c1, m_mode)); + auto w_approx = m_A * ntt(std::move(z)); + w_approx -= c_hat * ntt(m_pub_key->t1() << DilithiumConstants::D); + w_approx.reduce(); + auto w1 = inverse_ntt(std::move(w_approx)); + w1.conditional_add_q(); + dilithium_use_hint(w1, h, m_mode); + + const auto chprime = sympri.H(mu, dilithium_encode_commitment(w1, m_mode)); + + BOTAN_ASSERT_NOMSG(ch.size() == chprime.size()); + return std::equal(ch.begin(), ch.end(), chprime.begin()); } - std::string hash_function() const override { return "SHAKE-256(512)"; } + std::string hash_function() const override { return m_h.name(); } private: std::shared_ptr m_pub_key; - const Dilithium::PolynomialMatrix m_matrix; - const std::vector m_pk_hash; - SHAKE_256 m_shake; + const DilithiumConstants& m_mode; + DilithiumPolyMatNTT m_A; + DilithiumMessageHash m_h; }; Dilithium_PublicKey::Dilithium_PublicKey(const AlgorithmIdentifier& alg_id, std::span pk) : Dilithium_PublicKey(pk, DilithiumMode(alg_id.oid())) {} Dilithium_PublicKey::Dilithium_PublicKey(std::span pk, DilithiumMode m) { - DilithiumModeConstants mode(m); + DilithiumConstants mode(m); BOTAN_ARG_CHECK(pk.empty() || pk.size() == mode.public_key_bytes(), "dilithium public key does not have the correct byte count"); - m_public = std::make_shared(std::move(mode), pk); + m_public = Dilithium_PublicKeyInternal::decode(std::move(mode), StrongSpan(pk)); } std::string Dilithium_PublicKey::algo_name() const { @@ -521,7 +358,7 @@ AlgorithmIdentifier Dilithium_PublicKey::algorithm_identifier() const { } OID Dilithium_PublicKey::object_identifier() const { - return m_public->mode().oid(); + return m_public->mode().mode().object_identifier(); } size_t Dilithium_PublicKey::key_length() const { @@ -529,11 +366,11 @@ size_t Dilithium_PublicKey::key_length() const { } size_t Dilithium_PublicKey::estimated_strength() const { - return m_public->mode().nist_security_strength(); + return m_public->mode().lambda(); } std::vector Dilithium_PublicKey::raw_public_key_bits() const { - return m_public->raw_pk(); + return m_public->raw_pk().get(); } std::vector Dilithium_PublicKey::public_key_bits() const { @@ -554,7 +391,7 @@ std::unique_ptr Dilithium_PublicKey::create_verification_o std::string_view provider) const { BOTAN_ARG_CHECK(params.empty() || params == "Pure", "Unexpected parameters for verifying with Dilithium"); if(provider.empty() || provider == "base") { - return std::make_unique(*this); + return std::make_unique(m_public); } throw Provider_Not_Found(algo_name(), provider); } @@ -565,56 +402,64 @@ std::unique_ptr Dilithium_PublicKey::create_x509_verificat if(alg_id != this->algorithm_identifier()) { throw Decoding_Error("Unexpected AlgorithmIdentifier for Dilithium X.509 signature"); } - return std::make_unique(*this); + return std::make_unique(m_public); } throw Provider_Not_Found(algo_name(), provider); } -Dilithium_PrivateKey::Dilithium_PrivateKey(RandomNumberGenerator& rng, DilithiumMode m) { - DilithiumModeConstants mode(m); - - secure_vector seedbuf = rng.random_vec(DilithiumModeConstants::SEEDBYTES); +namespace { - auto seed = mode.H(seedbuf, 2 * DilithiumModeConstants::SEEDBYTES + DilithiumModeConstants::CRHBYTES); +std::pair dilithium_compute_t1_and_t0(const DilithiumPolyMatNTT& A, + const DilithiumPolyVec& s1, + const DilithiumPolyVec& s2) { + auto t_hat = A * ntt(s1.clone()); + t_hat.reduce(); + auto t = inverse_ntt(std::move(t_hat)); + t += s2; + t.conditional_add_q(); - // seed is a concatenation of rho || rhoprime || key - std::vector rho(seed.begin(), seed.begin() + DilithiumModeConstants::SEEDBYTES); - secure_vector rhoprime(seed.begin() + DilithiumModeConstants::SEEDBYTES, - seed.begin() + DilithiumModeConstants::SEEDBYTES + DilithiumModeConstants::CRHBYTES); - secure_vector key(seed.begin() + DilithiumModeConstants::SEEDBYTES + DilithiumModeConstants::CRHBYTES, - seed.end()); + return dilithium_power2round(t); +} - BOTAN_ASSERT_NOMSG(rho.size() == DilithiumModeConstants::SEEDBYTES); - BOTAN_ASSERT_NOMSG(rhoprime.size() == DilithiumModeConstants::CRHBYTES); - BOTAN_ASSERT_NOMSG(key.size() == DilithiumModeConstants::SEEDBYTES); +} // namespace - /* Sample short vectors s1 and s2 */ - Dilithium::PolynomialVector s1(mode.l()); - Dilithium::PolynomialVector::fill_polyvec_uniform_eta(s1, rhoprime, 0, mode); +/** + * NIST FIPS 204 IPD, Algorithm 1 (ML-DSA.KeyGen) + */ +Dilithium_PrivateKey::Dilithium_PrivateKey(RandomNumberGenerator& rng, DilithiumMode m) { + DilithiumConstants mode(m); + const auto& sympriv = mode.symmetric_primitives(); - Dilithium::PolynomialVector s2(mode.k()); - Dilithium::PolynomialVector::fill_polyvec_uniform_eta(s2, rhoprime, mode.l(), mode); + const auto xi = rng.random_vec(DilithiumConstants::SEED_RANDOMNESS_BYTES); + auto [rho, rhoprime, key] = sympriv.H(xi); - auto [t0, t1] = calculate_t0_and_t1(mode, rho, s1, s2); + const auto A = dilithium_expand_A(rho, mode); + auto [s1, s2] = dilithium_expand_s(rhoprime, mode); + auto [t1, t0] = dilithium_compute_t1_and_t0(A, s1, s2); m_public = std::make_shared(mode, rho, std::move(t1)); - - /* Compute H(rho, t1) == H(pk) and write secret key */ - auto tr = mode.H(m_public->raw_pk(), DilithiumModeConstants::SEEDBYTES); - m_private = std::make_shared( - std::move(mode), std::move(rho), std::move(tr), std::move(key), std::move(s1), std::move(s2), std::move(t0)); + std::move(mode), std::move(rho), std::move(key), m_public->tr(), std::move(s1), std::move(s2), std::move(t0)); } Dilithium_PrivateKey::Dilithium_PrivateKey(const AlgorithmIdentifier& alg_id, std::span sk) : Dilithium_PrivateKey(sk, DilithiumMode(alg_id.oid())) {} Dilithium_PrivateKey::Dilithium_PrivateKey(std::span sk, DilithiumMode m) { - DilithiumModeConstants mode(m); + DilithiumConstants mode(m); BOTAN_ARG_CHECK(sk.size() == mode.private_key_bytes(), "dilithium private key does not have the correct byte count"); - m_private = std::make_shared(std::move(mode), sk); - m_public = std::make_shared( - m_private->mode(), m_private->rho(), m_private->s1(), m_private->s2()); + m_private = + Dilithium_PrivateKeyInternal::decode(std::move(mode), StrongSpan(sk)); + + // Currently, Botan's Private_Key class inherits from Public_Key, forcing us + // to derive the public key from the private key here. + const auto A = dilithium_expand_A(m_private->rho(), m_private->mode()); + auto [t1, _] = dilithium_compute_t1_and_t0(A, m_private->s1(), m_private->s2()); + m_public = std::make_shared(m_private->mode(), m_private->rho(), std::move(t1)); + + if(m_public->tr() != m_private->tr()) { + throw Decoding_Error("Calculated dilithium public key hash does not match the one stored in the private key"); + } } secure_vector Dilithium_PrivateKey::raw_private_key_bits() const { @@ -622,7 +467,7 @@ secure_vector Dilithium_PrivateKey::raw_private_key_bits() const { } secure_vector Dilithium_PrivateKey::private_key_bits() const { - return m_private->raw_sk(); + return std::move(m_private->raw_sk().get()); } std::unique_ptr Dilithium_PrivateKey::create_signature_op(RandomNumberGenerator& rng, @@ -633,9 +478,11 @@ std::unique_ptr Dilithium_PrivateKey::create_signature_op(Ran BOTAN_ARG_CHECK(params.empty() || params == "Deterministic" || params == "Randomized", "Unexpected parameters for signing with Dilithium"); + // TODO: ML-DSA uses the randomized (hedged) variant by default. + // We might even drop support for the deterministic variant. const bool randomized = (params == "Randomized"); if(provider.empty() || provider == "base") { - return std::make_unique(*this, randomized); + return std::make_unique(m_private, randomized); } throw Provider_Not_Found(algo_name(), provider); } diff --git a/src/lib/pubkey/dilithium/dilithium_common/dilithium_algos.cpp b/src/lib/pubkey/dilithium/dilithium_common/dilithium_algos.cpp new file mode 100644 index 00000000000..4b95c08a589 --- /dev/null +++ b/src/lib/pubkey/dilithium/dilithium_common/dilithium_algos.cpp @@ -0,0 +1,781 @@ +/* + * Crystals Dilithium Internal Algorithms (aka. "Auxiliary Functions") + * + * This implements the auxiliary functions of the Crystals Dilithium signature + * scheme as specified in NIST FIPS 204 IPD, Chapter 8. + * + * Some implementations are based on the public domain reference implementation + * by the designers (https://github.com/pq-crystals/dilithium) + * + * (C) 2021-2024 Jack Lloyd + * (C) 2021-2022 Manuel Glaser and Michael Boric, Rohde & Schwarz Cybersecurity + * (C) 2021-2022 René Meusel and Hannes Rantzsch, neXenio GmbH + * (C) 2024 Fabian Albert and René Meusel, Rohde & Schwarz Cybersecurity + * + * Botan is released under the Simplified BSD License (see license.txt) + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace Botan { + +namespace { + +/** + * Returns an all-one mask if @p x is negative, otherwise an all-zero mask. + */ +template +constexpr auto is_negative_mask(T x) { + using unsigned_T = std::make_unsigned_t; + return CT::Mask::expand_top_bit(static_cast(x)); +} + +template +constexpr std::make_unsigned_t map_range(DilithiumConstants::T c) { + BOTAN_DEBUG_ASSERT(b - c >= 0); + return b - c; +} + +template +constexpr DilithiumConstants::T unmap_range(std::make_unsigned_t c) { + return static_cast(b - c); +} + +template +constexpr void poly_pack(const CRYSTALS::Polynomial& p, BufferStuffer& stuffer) { + if constexpr(a == 0) { + // If `a` is 0, we assume SimpleBitPack (Algorithm 10) where the + // coefficients are in the range [0, b]. + CRYSTALS::pack(p, stuffer); + } else { + // Otherwise, for BitPack (Algorithm 11), we must map the coefficients to + // positive values as they are in the range [-a, b]. + CRYSTALS::pack(p, stuffer, map_range); + } +} + +template +constexpr void poly_unpack(CRYSTALS::Polynomial& p, ByteSourceT& get_bytes, bool check_range = false) { + if constexpr(a == 0) { + // If `a` is 0, we assume SimpleBitUnpack (Algorithm 12) where the + // coefficients are in the range [0, b]. + CRYSTALS::unpack(p, get_bytes); + } else { + // Otherwise, BitUnpack (Algorithm 13) must map the unpacked coefficients + // to the range [-a, b]. + CRYSTALS::unpack(p, get_bytes, unmap_range); + } + if(check_range && !p.ct_validate_value_range(-a, b)) { + throw Decoding_Error("Decoded polynomial coefficients out of range"); + } +} + +/** + * NIST FIPS 204 IPD, Algorithm 10 (SimpleBitPack) + * (for a = 2^(bitlen(q-1)-d) - 1) + */ +void dilithium_poly_pack_t1(const DilithiumPoly& p, BufferStuffer& stuffer) { + constexpr auto b = (1 << (bitlen(DilithiumConstants::Q - 1) - DilithiumConstants::D)) - 1; + poly_pack<0, b>(p, stuffer); +} + +/** + * NIST FIPS 204 IPD, Algorithm 10 (SimpleBitPack) + * (for a = (q-1)/(2*gamma2-1)) + */ +void dilithium_poly_pack_w1(const DilithiumPoly& p, BufferStuffer& stuffer, const DilithiumConstants& mode) { + using Gamma2 = DilithiumConstants::DilithiumGamma2; + auto calculate_b = [](auto gamma2) { return ((DilithiumConstants::Q - 1) / (2 * gamma2)) - 1; }; + switch(mode.gamma2()) { + case Gamma2::Qminus1DevidedBy88: + return poly_pack<0, calculate_b(Gamma2::Qminus1DevidedBy88)>(p, stuffer); + case Gamma2::Qminus1DevidedBy32: + return poly_pack<0, calculate_b(Gamma2::Qminus1DevidedBy32)>(p, stuffer); + } + + BOTAN_ASSERT_UNREACHABLE(); +} + +/** + * NIST FIPS 204 IPD, Algorithm 11 (BitPack) + * (for a = -gamma1 - 1, b = gamma1) + */ +void dilithium_poly_pack_gamma1(const DilithiumPoly& p, BufferStuffer& stuffer, const DilithiumConstants& mode) { + using Gamma1 = DilithiumConstants::DilithiumGamma1; + switch(mode.gamma1()) { + case Gamma1::ToThe17th: + return poly_pack(p, stuffer); + case Gamma1::ToThe19th: + return poly_pack(p, stuffer); + } + + BOTAN_ASSERT_UNREACHABLE(); +} + +/** + * NIST FIPS 204 IPD, Algorithm 11 (BitPack) + * (for a = -eta, b = eta) + */ +void dilithium_poly_pack_eta(const DilithiumPoly& p, BufferStuffer& stuffer, const DilithiumConstants& mode) { + using Eta = DilithiumConstants::DilithiumEta; + switch(mode.eta()) { + case Eta::_2: + return poly_pack(p, stuffer); + case Eta::_4: + return poly_pack(p, stuffer); + } + + BOTAN_ASSERT_UNREACHABLE(); +} + +/** + * NIST FIPS 204 IPD, Algorithm 11 (BitPack) + * (for a = -2^(d-1) - 1, b = 2^(d-1)) + */ +void dilithium_poly_pack_t0(const DilithiumPoly& p, BufferStuffer& stuffer) { + constexpr auto TwoToTheDminus1 = 1 << (DilithiumConstants::D - 1); + poly_pack(p, stuffer); +} + +/** + * NIST FIPS 204 IPD, Algorithm 12 (SimpleBitUnpack) + * (for a = 2^(bitlen(q-1)-d) - 1) + */ +void dilithium_poly_unpack_t1(DilithiumPoly& p, BufferSlicer& slicer) { + constexpr auto b = (1 << (bitlen(DilithiumConstants::Q - 1) - DilithiumConstants::D)) - 1; + poly_unpack<0, b>(p, slicer); +} + +/** + * NIST FIPS 204 IPD, Algorithm 13 (BitUnpack) + * (for a = -gamma1 - 1, b = gamma1) + */ +template +void dilithium_poly_unpack_gamma1(DilithiumPoly& p, ByteSourceT& byte_source, const DilithiumConstants& mode) { + using Gamma1 = DilithiumConstants::DilithiumGamma1; + switch(mode.gamma1()) { + case Gamma1::ToThe17th: + return poly_unpack(p, byte_source); + case Gamma1::ToThe19th: + return poly_unpack(p, byte_source); + } + + BOTAN_ASSERT_UNREACHABLE(); +} + +/** + * NIST FIPS 204 IPD, Algorithm 13 (BitUnpack) + * (for a = -eta, b = eta) + */ +void dilithium_poly_unpack_eta(DilithiumPoly& p, + BufferSlicer& slicer, + const DilithiumConstants& mode, + bool check_range = false) { + using Eta = DilithiumConstants::DilithiumEta; + switch(mode.eta()) { + case Eta::_2: + return poly_unpack(p, slicer, check_range); + case Eta::_4: + return poly_unpack(p, slicer, check_range); + } + + BOTAN_ASSERT_UNREACHABLE(); +} + +/** + * NIST FIPS 204 IPD, Algorithm 13 (BitUnpack) + * (for a = -2^(d-1) - 1, b = 2^(d-1)) + */ +void dilithium_poly_unpack_t0(DilithiumPoly& p, BufferSlicer& slicer) { + constexpr auto TwoToTheDminus1 = 1 << (DilithiumConstants::D - 1); + poly_unpack(p, slicer); +} + +/** + * NIST FIPS 204 IPD, Algorithm 14 (HintBitPack) + */ +void dilithium_hint_pack(const DilithiumPolyVec& h, BufferStuffer& stuffer, const DilithiumConstants& mode) { + BOTAN_ASSERT_NOMSG(h.size() == mode.k()); + BOTAN_DEBUG_ASSERT(h.ct_validate_value_range(0, 1)); + + BufferStuffer bit_positions(stuffer.next(mode.omega())); + BufferStuffer offsets(stuffer.next(mode.k())); + + uint8_t index = 0; + for(const auto& p : h) { + for(size_t i = 0; i < p.size(); ++i) { + if(p[i] == 1) { + bit_positions.append(static_cast(i)); + ++index; + } + } + offsets.append(index); + } + + // Fill the remaining bit positions with zeros + bit_positions.append(0, bit_positions.remaining_capacity()); +} + +/** + * NIST FIPS 204 IPD, Algorithm 15 (HintBitUnpack) + */ +std::optional dilithium_hint_unpack(BufferSlicer& slicer, const DilithiumConstants& mode) { + BufferSlicer bit_positions(slicer.take(mode.omega())); + BufferSlicer offsets(slicer.take(mode.k())); + + DilithiumPolyVec hint(mode.k()); + uint8_t index = 0; + for(auto& p : hint) { + const auto end_index = offsets.take_byte(); + + // Check the bounds of the end index for this polynomial + if(end_index < index || end_index > mode.omega()) { + return std::nullopt; + } + + const auto set_bits = bit_positions.take(end_index - index); + + // Check that the set bit positions are ordered (strong unforgeability) + // TODO: explicitly add a test for this, Whycheproof perhaps? + for(size_t i = 1; i < set_bits.size(); ++i) { + if(set_bits[i] <= set_bits[i - 1]) { + return std::nullopt; + } + } + + // Set the specified bits in the polynomial + for(const auto i : set_bits) { + p[i] = 1; + } + + index = end_index; + } + + // Check that the remaining bit positions are all zero (strong unforgeability) + const auto remaining = bit_positions.take(bit_positions.remaining()); + if(!std::all_of(remaining.begin(), remaining.end(), [](auto b) { return b == 0; })) { + return std::nullopt; + } + + BOTAN_DEBUG_ASSERT(hint.ct_validate_value_range(0, 1)); + return hint; +} + +} // namespace + +/** + * NIST FIPS 204 IPD, Algorithm 16 (pkEncode) + */ +DilithiumSerializedPublicKey dilithium_encode_public_key(StrongSpan rho, + const DilithiumPolyVec& t1, + const DilithiumConstants& mode) { + DilithiumSerializedPublicKey pk(mode.public_key_bytes()); + BufferStuffer stuffer(pk); + + stuffer.append(rho); + for(const auto& p : t1) { + dilithium_poly_pack_t1(p, stuffer); + } + + BOTAN_ASSERT_NOMSG(stuffer.full()); + return pk; +} + +/** + * NIST FIPS 204 IPD, Algorithm 17 (pkDecode) + */ +std::pair dilithium_decode_public_key( + StrongSpan pk, const DilithiumConstants& mode) { + if(pk.size() != mode.public_key_bytes()) { + throw Decoding_Error("Dilithium: Invalid public key length"); + } + + BufferSlicer slicer(pk); + auto rho = slicer.copy(DilithiumConstants::SEED_RHO_BYTES); + + DilithiumPolyVec t1(mode.k()); + for(auto& p : t1) { + dilithium_poly_unpack_t1(p, slicer); + } + BOTAN_ASSERT_NOMSG(slicer.empty()); + + return {std::move(rho), std::move(t1)}; +} + +/** + * NIST FIPS 204 IPD, Algorithm 18 (skEncode) + */ +DilithiumSerializedPrivateKey dilithium_encode_private_key(StrongSpan rho, + StrongSpan tr, + StrongSpan key, + const DilithiumPolyVec& s1, + const DilithiumPolyVec& s2, + const DilithiumPolyVec& t0, + const DilithiumConstants& mode) { + DilithiumSerializedPrivateKey sk(mode.private_key_bytes()); + BufferStuffer stuffer(sk); + + stuffer.append(rho); + stuffer.append(key); + stuffer.append(tr); + + for(const auto& p : s1) { + dilithium_poly_pack_eta(p, stuffer, mode); + } + + for(const auto& p : s2) { + dilithium_poly_pack_eta(p, stuffer, mode); + } + + for(const auto& p : t0) { + dilithium_poly_pack_t0(p, stuffer); + } + + BOTAN_ASSERT_NOMSG(stuffer.full()); + return sk; +} + +/** + * NIST FIPS 204 IPD, Algorithm 19 (skDecode) + */ +std::tuple +dilithium_decode_private_key(StrongSpan sk, const DilithiumConstants& mode) { + if(sk.size() != mode.private_key_bytes()) { + throw Decoding_Error("Dilithium: Invalid private key length"); + } + + BufferSlicer slicer(sk); + + auto rho = slicer.copy(DilithiumConstants::SEED_RHO_BYTES); + auto key = slicer.copy(DilithiumConstants::SEED_SIGNING_KEY_BYTES); + auto tr = slicer.copy(DilithiumConstants::PUBLIC_KEY_HASH_BYTES); + + DilithiumPolyVec s1(mode.l()); + for(auto& p : s1) { + dilithium_poly_unpack_eta(p, slicer, mode, true /* check decoded value range */); + } + + DilithiumPolyVec s2(mode.k()); + for(auto& p : s2) { + dilithium_poly_unpack_eta(p, slicer, mode, true /* check decoded value range */); + } + + DilithiumPolyVec t0(mode.k()); + for(auto& p : t0) { + dilithium_poly_unpack_t0(p, slicer); + } + + BOTAN_ASSERT_NOMSG(slicer.empty()); + return {std::move(rho), std::move(key), std::move(tr), std::move(s1), std::move(s2), std::move(t0)}; +} + +/** + * NIST FIPS 204 IPD, Algorithm 20 (sigEncode) + */ +DilithiumSerializedSignature dilithium_encode_signature(StrongSpan c, + const DilithiumPolyVec& response, + const DilithiumPolyVec& hint, + const DilithiumConstants& mode) { + DilithiumSerializedSignature sig(mode.signature_bytes()); + BufferStuffer stuffer(sig); + + stuffer.append(c); + for(const auto& p : response) { + dilithium_poly_pack_gamma1(p, stuffer, mode); + } + dilithium_hint_pack(hint, stuffer, mode); + + return sig; +} + +/** + * NIST FIPS 204 IPD, Algorithm 21 (sigDecode) + */ +std::optional> dilithium_decode_signature( + StrongSpan sig, const DilithiumConstants& mode) { + BufferSlicer slicer(sig); + BOTAN_ASSERT_NOMSG(slicer.remaining() == mode.signature_bytes()); + + auto commitment_hash = slicer.copy(DilithiumConstants::COMMITMENT_HASH_C1_BYTES); + + DilithiumPolyVec response(mode.l()); + for(auto& p : response) { + dilithium_poly_unpack_gamma1(p, slicer, mode); + } + BOTAN_ASSERT_NOMSG(slicer.remaining() == mode.omega() + mode.k()); + + auto hint = dilithium_hint_unpack(slicer, mode); + BOTAN_ASSERT_NOMSG(slicer.empty()); + if(!hint.has_value()) { + return std::nullopt; + } + + return std::make_tuple(std::move(commitment_hash), std::move(response), std::move(hint.value())); +} + +/** + * NIST FIPS 204 IPD, Algorithm 22 (w1Encode) + */ +DilithiumSerializedCommitment dilithium_encode_commitment(const DilithiumPolyVec& w1, const DilithiumConstants& mode) { + DilithiumSerializedCommitment commitment(mode.serialized_commitment_bytes()); + BufferStuffer stuffer(commitment); + + for(const auto& p : w1) { + dilithium_poly_pack_w1(p, stuffer, mode); + } + + return commitment; +} + +/** + * NIST FIPS 204 IPD, Algorithm 23 (SampleInBall) + */ +DilithiumPoly dilithium_sample_in_ball(StrongSpan seed, const DilithiumConstants& mode) { + auto xof = mode.symmetric_primitives().H(seed); + + // This generator resembles the while loop in the spec. + auto next_byte_lower_than = [&xof](size_t i) -> uint8_t { + while(true) { + if(const uint8_t b = xof.output_next_byte(); b <= i) { + return b; + } + } + }; + + DilithiumPoly c; + uint64_t signs = load_le(xof.output<8>()); + for(size_t i = c.size() - mode.tau(); i < c.size(); ++i) { + const auto j = next_byte_lower_than(i); + c[i] = c[j]; + c[j] = 1 - 2 * (signs & 1); + signs >>= 1; + } + + BOTAN_DEBUG_ASSERT(c.ct_validate_value_range(-1, 1)); + BOTAN_DEBUG_ASSERT(c.hamming_weight() == mode.tau()); + + return c; +} + +namespace { + +/** + * NIST FIPS 204 IPD, Algorithm 24 (RejNTTPoly) + */ +void dilithium_sample_ntt_uniform(StrongSpan rho, + DilithiumPolyNTT& p, + uint16_t nonce, + const DilithiumConstants& mode) { + /** + * A generator that returns the next coefficient sampled from the XOF, + * according to: NIST FIPS 204 IPD, Algorithm 8 (CoeffFromThreeBytes). + */ + auto next_coeff = [](Botan::XOF& xof) -> uint32_t { + std::array bytes = {0}; + std::span sampling_sink_in_bytes = std::span{bytes}.first<3>(); + + while(true) { + xof.output(sampling_sink_in_bytes); + const auto z = load_le(bytes) & 0x7FFFFF; + if(z < DilithiumConstants::Q) { + return z; + } + } + }; + + auto& xof = mode.symmetric_primitives().H(rho, nonce); + for(auto& coeff : p) { + coeff = next_coeff(xof); + } + + BOTAN_DEBUG_ASSERT(p.ct_validate_value_range(0, DilithiumConstants::Q - 1)); +} + +/** + * NIST FIPS 204 IPD, Algorithm 25 (RejBoundedPoly) + */ +void dilithium_sample_uniform_eta(StrongSpan rhoprime, + DilithiumPoly& p, + uint16_t nonce, + const DilithiumConstants& mode) { + using Eta = DilithiumConstants::DilithiumEta; + + /** + * NIST FIPS 204 IPD, Algorithm 9 (CoeffFromHalfByte) + */ + auto coeff_from_halfbyte = [eta = mode.eta()](uint8_t b) -> std::optional { + BOTAN_DEBUG_ASSERT(b < 16); + + if(eta == Eta::_2 && b < 15) { + b = b - (205 * b >> 10) * 5; + return 2 - b; + } + + if(eta == Eta::_4 && b < 9) { + return 4 - b; + } + + return std::nullopt; + }; + + // A generator that returns the next coefficient sampled from the XOF. As the + // sampling uses half-bytes, this keeps track of the additionally sampled + // coefficient as needed. + auto next_coeff = [&, stashed_coeff = std::optional{}](Botan::XOF& xof) mutable -> int32_t { + if(auto stashed = std::exchange(stashed_coeff, std::nullopt)) { + return *stashed; + } + + BOTAN_DEBUG_ASSERT(!stashed_coeff.has_value()); + while(true) { + const auto b = xof.output_next_byte(); + const auto z0 = coeff_from_halfbyte(b & 0x0F); + const auto z1 = coeff_from_halfbyte(b >> 4); + + if(z0.has_value()) { + stashed_coeff = z1; // keep candidate z1 for the next invocation + return *z0; + } else if(z1.has_value()) { + // z0 was invalid, z1 is valid, nothing to stash + return *z1; + } + } + }; + + auto& xof = mode.symmetric_primitives().H(rhoprime, nonce); + for(auto& coeff : p) { + coeff = next_coeff(xof); + } + + BOTAN_DEBUG_ASSERT(p.ct_validate_value_range(-static_cast(mode.eta()), mode.eta())); +} + +} // namespace + +/** + * NIST FIPS 204 IPD, Algorithm 26 (ExpandA) + */ +DilithiumPolyMatNTT dilithium_expand_A(StrongSpan rho, const DilithiumConstants& mode) { + DilithiumPolyMatNTT A(mode.k(), mode.l()); + for(uint8_t r = 0; r < mode.k(); ++r) { + for(uint8_t s = 0; s < mode.l(); ++s) { + // In FIPS 204 IPD this is denoted as IntegerToBits(s,8)||IntegerToBits(r,8) + const uint16_t nonce = make_uint16(r, s); + dilithium_sample_ntt_uniform(rho, A[r][s], nonce, mode); + } + } + return A; +} + +/** + * NIST FIPS 204 IPD, Algorithm 27 (ExpandS) + */ +std::pair dilithium_expand_s(StrongSpan rhoprime, + const DilithiumConstants& mode) { + DilithiumPolyVec s1(mode.l()); + DilithiumPolyVec s2(mode.k()); + + uint16_t nonce = 0; + for(auto& p : s1) { + dilithium_sample_uniform_eta(rhoprime, p, nonce++, mode); + } + + for(auto& p : s2) { + dilithium_sample_uniform_eta(rhoprime, p, nonce++, mode); + } + + return {std::move(s1), std::move(s2)}; +} + +/** + * NIST FIPS 204 IPD, Algorithm 28 (ExpandMask) + */ +DilithiumPolyVec dilithium_expand_mask(StrongSpan rhoprime, + uint16_t nonce, + const DilithiumConstants& mode) { + DilithiumPolyVec s(mode.l()); + for(auto& p : s) { + auto& xof = mode.symmetric_primitives().H(rhoprime, nonce++); + dilithium_poly_unpack_gamma1(p, xof, mode); + } + return s; +} + +/** + * NIST FIPS 204 IPD, Algorithm 29 (Power2Round) + */ +std::pair dilithium_power2round(const DilithiumPolyVec& vec) { + // This procedure is taken verbatim from Dilithium's reference implementation. + auto power2round = [d = DilithiumConstants::D](int32_t r) -> std::pair { + const int32_t r1 = (r + (1 << (d - 1)) - 1) >> d; + const int32_t r0 = r - (r1 << d); + return {r1, r0}; + }; + + auto result = std::make_pair(DilithiumPolyVec(vec.size()), DilithiumPolyVec(vec.size())); + + for(size_t i = 0; i < vec.size(); ++i) { + for(size_t j = 0; j < vec[i].size(); ++j) { + std::tie(result.first[i][j], result.second[i][j]) = power2round(vec[i][j]); + } + } + + return result; +} + +namespace { + +auto dilithium_decompose_fn(const DilithiumConstants& mode) { + using Gamma2 = DilithiumConstants::DilithiumGamma2; + + // This procedure is taken verbatim from Dilithium's reference implementation. + return [gamma2 = mode.gamma2(), q = DilithiumConstants::Q](int32_t r) -> std::pair { + int32_t r1 = (r + 127) >> 7; + + switch(gamma2) { + case Gamma2::Qminus1DevidedBy32: + r1 = (r1 * 1025 + (1 << 21)) >> 22; + r1 &= 15; + break; + case Gamma2::Qminus1DevidedBy88: + r1 = (r1 * 11275 + (1 << 23)) >> 24; + r1 ^= is_negative_mask(43 - r1).if_set_return(r1); + break; + } + + int32_t r0 = r - r1 * 2 * gamma2; + r0 -= is_negative_mask((q - 1) / 2 - r0).if_set_return(q); + + return {r1, r0}; + }; +} + +} // namespace + +/** + * NIST FIPS 204 IPD, Algorithm 30 (Decompose) + * + * Algorithms 31 (HighBits) and 32 (LowBits) are not implemented explicitly, + * simply use the first (HighBits) and second (LowBits) element of the result. + */ +std::pair dilithium_decompose(const DilithiumPolyVec& vec, + const DilithiumConstants& mode) { + auto decompose = dilithium_decompose_fn(mode); + auto result = std::make_pair(DilithiumPolyVec(vec.size()), DilithiumPolyVec(vec.size())); + + for(size_t i = 0; i < vec.size(); ++i) { + for(size_t j = 0; j < vec[i].size(); ++j) { + std::tie(result.first[i][j], result.second[i][j]) = decompose(vec[i][j]); + } + } + + return result; +} + +/** + * NIST FIPS 204 IPD, Algorithm 33 (MakeHint) + */ +DilithiumPolyVec dilithium_make_hint(const DilithiumPolyVec& z, + const DilithiumPolyVec& r, + const DilithiumConstants& mode) { + BOTAN_DEBUG_ASSERT(z.size() == r.size()); + + auto make_hint = [gamma2 = int32_t(mode.gamma2()), q_gamma2 = DilithiumConstants::Q - int32_t(mode.gamma2())]( + int32_t c0, int32_t c1) -> bool { + if(c0 <= gamma2 || c0 > q_gamma2 || (c0 == q_gamma2 && c1 == 0)) { + return false; + } + return true; + }; + + DilithiumPolyVec hint(r.size()); + + for(size_t i = 0; i < r.size(); ++i) { + for(size_t j = 0; j < r[i].size(); ++j) { + hint[i][j] = make_hint(z[i][j], r[i][j]); + } + } + + BOTAN_DEBUG_ASSERT(hint.ct_validate_value_range(0, 1)); + + return hint; +} + +/** + * NIST FIPS 204 IPD, Algorithm 34 (UseHint) + */ +void dilithium_use_hint(DilithiumPolyVec& vec, const DilithiumPolyVec& hints, const DilithiumConstants& mode) { + using Gamma2 = DilithiumConstants::DilithiumGamma2; + + BOTAN_DEBUG_ASSERT(hints.size() == vec.size()); + BOTAN_DEBUG_ASSERT(hints.ct_validate_value_range(0, 1)); + BOTAN_DEBUG_ASSERT(vec.ct_validate_value_range(0, DilithiumConstants::Q - 1)); + + auto use_hint = [gamma2 = mode.gamma2(), decompose = dilithium_decompose_fn(mode)](bool hint, int32_t c) -> int32_t { + auto [a1, a0] = decompose(c); + + if(!hint) { + return a1; + } + + switch(gamma2) { + case Gamma2::Qminus1DevidedBy32: + BOTAN_DEBUG_ASSERT(a1 > -16 && a1 < 16); + return (a0 > 0) ? (a1 + 1) & 15 : (a1 - 1) & 15; + case Gamma2::Qminus1DevidedBy88: + BOTAN_DEBUG_ASSERT(a1 > -44 && a1 < 44); + // NOLINTNEXTLINE(*-avoid-nested-conditional-operator) + return (a0 > 0) ? ((a1 == 43) ? 0 : a1 + 1) : ((a1 == 0) ? 43 : a1 - 1); + } + + BOTAN_ASSERT_UNREACHABLE(); + }; + + for(size_t i = 0; i < vec.size(); ++i) { + for(size_t j = 0; j < vec[i].size(); ++j) { + vec[i][j] = use_hint(hints[i][j], vec[i][j]); + } + } + + BOTAN_DEBUG_ASSERT(vec.ct_validate_value_range(0, (DilithiumConstants::Q - 1) / (2 * mode.gamma2()))); +} + +bool dilithium_infinity_norm_within_bound(const DilithiumPolyVec& vec, size_t bound) { + BOTAN_DEBUG_ASSERT(bound <= (DilithiumConstants::Q - 1) / 8); + + // It is ok to leak which coefficient violates the bound as the probability + // for each coefficient is independent of secret data but we must not leak + // the sign of the centralized representative. + for(const auto& p : vec) { + for(auto c : p) { + const auto abs_c = c - is_negative_mask(c).if_set_return(2 * c); + if(abs_c >= bound) { + return false; + } + } + } + + return true; +} + +} // namespace Botan diff --git a/src/lib/pubkey/dilithium/dilithium_common/dilithium_algos.h b/src/lib/pubkey/dilithium/dilithium_common/dilithium_algos.h new file mode 100644 index 00000000000..692b20839ea --- /dev/null +++ b/src/lib/pubkey/dilithium/dilithium_common/dilithium_algos.h @@ -0,0 +1,76 @@ +/* + * Crystals Dilithium Internal Algorithms + * + * (C) 2021-2024 Jack Lloyd + * (C) 2021-2022 Manuel Glaser and Michael Boric, Rohde & Schwarz Cybersecurity + * (C) 2021-2022 René Meusel and Hannes Rantzsch, neXenio GmbH + * (C) 2024 René Meusel, Rohde & Schwarz Cybersecurity + * + * Botan is released under the Simplified BSD License (see license.txt) + */ + +#ifndef BOTAN_DILITHIUM_ALGOS_H_ +#define BOTAN_DILITHIUM_ALGOS_H_ + +#include + +namespace Botan { + +DilithiumPolyMatNTT dilithium_expand_A(StrongSpan rho, const DilithiumConstants& mode); +std::pair dilithium_expand_s(StrongSpan rhoprime, + const DilithiumConstants& mode); +DilithiumPolyVec dilithium_expand_mask(StrongSpan rhoprime, + uint16_t nonce, + const DilithiumConstants& mode); + +DilithiumSerializedCommitment dilithium_encode_commitment(const DilithiumPolyVec& w1, const DilithiumConstants& mode); + +DilithiumPoly dilithium_sample_in_ball(StrongSpan seed, const DilithiumConstants& mode); + +std::optional> dilithium_decode_signature( + StrongSpan sig, const DilithiumConstants& mode); + +DilithiumSerializedSignature dilithium_encode_signature(StrongSpan c, + const DilithiumPolyVec& response, + const DilithiumPolyVec& hint, + const DilithiumConstants& mode); + +DilithiumSerializedPublicKey dilithium_encode_public_key(StrongSpan rho, + const DilithiumPolyVec& t1, + const DilithiumConstants& mode); + +std::pair dilithium_decode_public_key( + StrongSpan pk, const DilithiumConstants& mode); + +DilithiumSerializedPrivateKey dilithium_encode_private_key(StrongSpan rho, + StrongSpan tr, + StrongSpan key, + const DilithiumPolyVec& s1, + const DilithiumPolyVec& s2, + const DilithiumPolyVec& t0, + const DilithiumConstants& mode); + +std::tuple +dilithium_decode_private_key(StrongSpan sk, const DilithiumConstants& mode); + +std::pair dilithium_power2round(const DilithiumPolyVec& vec); + +std::pair dilithium_decompose(const DilithiumPolyVec& vec, + const DilithiumConstants& mode); + +DilithiumPolyVec dilithium_make_hint(const DilithiumPolyVec& z, + const DilithiumPolyVec& r, + const DilithiumConstants& mode); + +void dilithium_use_hint(DilithiumPolyVec& vec, const DilithiumPolyVec& hints, const DilithiumConstants& mode); + +bool dilithium_infinity_norm_within_bound(const DilithiumPolyVec& vec, size_t bound); + +} // namespace Botan + +#endif diff --git a/src/lib/pubkey/dilithium/dilithium_common/dilithium_constants.cpp b/src/lib/pubkey/dilithium/dilithium_common/dilithium_constants.cpp new file mode 100644 index 00000000000..18cfa440813 --- /dev/null +++ b/src/lib/pubkey/dilithium/dilithium_common/dilithium_constants.cpp @@ -0,0 +1,73 @@ +/** + * Asymmetric primitives for dilithium + * + * (C) 2022-2023 Jack Lloyd + * (C) 2022-2024 Michael Boric, René Meusel - Rohde & Schwarz Cybersecurity + * (C) 2022 Manuel Glaser - Rohde & Schwarz Cybersecurity + * + * Botan is released under the Simplified BSD License (see license.txt) + */ + +#include + +#include + +namespace Botan { + +DilithiumConstants::DilithiumConstants(DilithiumMode mode) : m_mode(mode) { + switch(m_mode.mode()) { + case Botan::DilithiumMode::Dilithium4x4: + case Botan::DilithiumMode::Dilithium4x4_AES: + m_tau = DilithiumTau::_39; + m_lambda = DilithiumLambda::_128; + m_gamma1 = DilithiumGamma1::ToThe17th; + m_gamma2 = DilithiumGamma2::Qminus1DevidedBy88; + m_k = 4; + m_l = 4; + m_eta = DilithiumEta::_2; + m_beta = DilithiumBeta::_78; + m_omega = DilithiumOmega::_80; + break; + case Botan::DilithiumMode::Dilithium6x5: + case Botan::DilithiumMode::Dilithium6x5_AES: + m_tau = DilithiumTau::_49; + m_lambda = DilithiumLambda::_192; + m_gamma1 = DilithiumGamma1::ToThe19th; + m_gamma2 = DilithiumGamma2::Qminus1DevidedBy32; + m_k = 6; + m_l = 5; + m_eta = DilithiumEta::_4; + m_beta = DilithiumBeta::_196; + m_omega = DilithiumOmega::_55; + break; + case Botan::DilithiumMode::Dilithium8x7: + case Botan::DilithiumMode::Dilithium8x7_AES: + m_tau = DilithiumTau::_60; + m_lambda = DilithiumLambda::_256; + m_gamma1 = DilithiumGamma1::ToThe19th; + m_gamma2 = DilithiumGamma2::Qminus1DevidedBy32; + m_k = 8; + m_l = 7; + m_eta = DilithiumEta::_2; + m_beta = DilithiumBeta::_120; + m_omega = DilithiumOmega::_75; + break; + } + + const auto s1_bytes = 32 * m_l * bitlen(2 * m_eta); + const auto s2_bytes = 32 * m_k * bitlen(2 * m_eta); + const auto t0_bytes = 32 * m_k * D; + const auto t1_bytes = 32 * m_k * (bitlen(static_cast(Q) - 1) - D); + const auto z_bytes = 32 * m_l * (1 + bitlen(m_gamma1 - 1)); + const auto hint_bytes = m_omega + m_k; + + m_private_key_bytes = + SEED_RHO_BYTES + SEED_SIGNING_KEY_BYTES + PUBLIC_KEY_HASH_BYTES + s1_bytes + s2_bytes + t0_bytes; + m_public_key_bytes = SEED_RHO_BYTES + t1_bytes; + m_signature_bytes = COMMITMENT_HASH_FULL_BYTES + z_bytes + hint_bytes; + m_serialized_commitment_bytes = 32 * m_k * bitlen(((Q - 1) / (2 * m_gamma2)) - 1); + + m_symmetric_primitives = Dilithium_Symmetric_Primitives::create(*this); +} + +} // namespace Botan diff --git a/src/lib/pubkey/dilithium/dilithium_common/dilithium_constants.h b/src/lib/pubkey/dilithium/dilithium_common/dilithium_constants.h new file mode 100644 index 00000000000..e8be9173e04 --- /dev/null +++ b/src/lib/pubkey/dilithium/dilithium_common/dilithium_constants.h @@ -0,0 +1,171 @@ +/* + * Crystals Dilithium Constants + * + * (C) 2022-2023 Jack Lloyd + * (C) 2022 Manuel Glaser - Rohde & Schwarz Cybersecurity + * (C) 2022-2023 Michael Boric, René Meusel - Rohde & Schwarz Cybersecurity + * (C) 2024 René Meusel, Rohde & Schwarz Cybersecurity + * + * Botan is released under the Simplified BSD License (see license.txt) + */ + +#ifndef BOTAN_DILITHIUM_CONSTANTS_H_ +#define BOTAN_DILITHIUM_CONSTANTS_H_ + +#include + +namespace Botan { + +class Dilithium_Symmetric_Primitives; + +/** + * Algorithm constants and parameter-set dependent values + */ +class DilithiumConstants final { + public: + /// base data type for most calculations + using T = int32_t; + + /// number of coefficients in a polynomial + static constexpr T N = 256; + + /// modulus + static constexpr T Q = 8380417; + + /// number of dropped bits from t (see FIPS 204 Section 5) + static constexpr T D = 13; + + /// as specified in FIPS 204 (see Algorithm 36 (NTT^-1), f = 256^-1 mod Q) + static constexpr T F = 8347681; + + /// the 512-th root of unity modulo Q (see FIPS 204 Section 8.5) + static constexpr T ROOT_OF_UNITY = 1753; + + /// degree of the NTT polynomials + static constexpr size_t NTT_Degree = 256; + + public: + /// \name Byte length's of various hash outputs and seeds + /// @{ + + static constexpr size_t SEED_RANDOMNESS_BYTES = 32; + static constexpr size_t SEED_RHO_BYTES = 32; + static constexpr size_t SEED_RHOPRIME_BYTES = 64; + static constexpr size_t SEED_SIGNING_KEY_BYTES = 32; + static constexpr size_t MESSAGE_HASH_BYTES = 64; + static constexpr size_t PUBLIC_KEY_HASH_BYTES = 32; + static constexpr size_t COMMITMENT_HASH_FULL_BYTES = 32; + static constexpr size_t COMMITMENT_HASH_C1_BYTES = 32; + + /// @} + + public: + enum DilithiumTau : uint32_t { _39 = 39, _49 = 49, _60 = 60 }; + + enum DilithiumLambda : uint32_t { _128 = 128, _192 = 192, _256 = 256 }; + + enum DilithiumGamma1 : uint32_t { ToThe17th = (1 << 17), ToThe19th = (1 << 19) }; + + enum DilithiumGamma2 : uint32_t { Qminus1DevidedBy88 = (Q - 1) / 88, Qminus1DevidedBy32 = (Q - 1) / 32 }; + + enum DilithiumEta : uint32_t { _2 = 2, _4 = 4 }; + + enum DilithiumBeta : uint32_t { _78 = 78, _196 = 196, _120 = 120 }; + + enum DilithiumOmega : uint32_t { _80 = 80, _55 = 55, _75 = 75 }; + + DilithiumConstants(DilithiumMode dimension); + ~DilithiumConstants() = default; + + DilithiumConstants(const DilithiumConstants& other) : DilithiumConstants(other.m_mode) {} + + DilithiumConstants(DilithiumConstants&& other) = default; + DilithiumConstants& operator=(const DilithiumConstants& other) = delete; + DilithiumConstants& operator=(DilithiumConstants&& other) = default; + + bool is_modern() const { return m_mode.is_modern(); } + + bool is_aes() const { return m_mode.is_aes(); } + + public: + /// \name Foundational constants + /// @{ + + /// hamming weight of the polynomial 'c' sampled from the commitment's hash + DilithiumTau tau() const { return m_tau; } + + /// collision strength of the commitment hash function + DilithiumLambda lambda() const { return m_lambda; } + + /// coefficient range of the randomly sampled mask 'y' + DilithiumGamma1 gamma1() const { return m_gamma1; } + + /// low-order rounding range for decomposing the commitment from polynomial vector 'w' + DilithiumGamma2 gamma2() const { return m_gamma2; } + + /// dimensions of the expanded matrix A + uint8_t k() const { return m_k; } + + /// dimensions of the expanded matrix A + uint8_t l() const { return m_l; } + + /// coefficient range of the private key's polynomial vectors 's1' and 's2' + DilithiumEta eta() const { return m_eta; } + + /// tau * eta + DilithiumBeta beta() const { return m_beta; } + + /// maximal hamming weight of the hint polynomial vector 'h' + DilithiumOmega omega() const { return m_omega; } + + /// length of the entire commitment hash in bytes + size_t commitment_hash_full_bytes() const { return COMMITMENT_HASH_FULL_BYTES; } + + /// @} + + /// \name Sizes of encoded data structures + /// @{ + + /// byte length of the encoded signature + size_t signature_bytes() const { return m_signature_bytes; } + + /// byte length of the encoded public key + size_t public_key_bytes() const { return m_public_key_bytes; } + + /// byte length of the encoded private key + size_t private_key_bytes() const { return m_private_key_bytes; } + + /// byte length of the packed commitment polynomial vector 'w1' + size_t serialized_commitment_bytes() const { return m_serialized_commitment_bytes; } + + /// @} + + DilithiumMode mode() const { return m_mode; } + + Dilithium_Symmetric_Primitives& symmetric_primitives() const { return *m_symmetric_primitives; } + + private: + DilithiumMode m_mode; + + DilithiumTau m_tau; + DilithiumLambda m_lambda; + DilithiumGamma1 m_gamma1; + DilithiumGamma2 m_gamma2; + uint8_t m_k; + uint8_t m_l; + DilithiumEta m_eta; + DilithiumBeta m_beta; + DilithiumOmega m_omega; + + uint32_t m_private_key_bytes; + uint32_t m_public_key_bytes; + uint32_t m_signature_bytes; + uint32_t m_serialized_commitment_bytes; + + // Mode dependent primitives + std::unique_ptr m_symmetric_primitives; +}; + +} // namespace Botan + +#endif diff --git a/src/lib/pubkey/dilithium/dilithium_common/dilithium_polynomial.h b/src/lib/pubkey/dilithium/dilithium_common/dilithium_polynomial.h new file mode 100644 index 00000000000..319d64eb266 --- /dev/null +++ b/src/lib/pubkey/dilithium/dilithium_common/dilithium_polynomial.h @@ -0,0 +1,119 @@ +/* + * Crystals Dilithium Polynomial Adapter + * + * (C) 2022-2023 Jack Lloyd + * (C) 2022 Manuel Glaser - Rohde & Schwarz Cybersecurity + * (C) 2022-2023 Michael Boric, René Meusel - Rohde & Schwarz Cybersecurity + * (C) 2024 René Meusel, Rohde & Schwarz Cybersecurity + * + * Botan is released under the Simplified BSD License (see license.txt) + */ + +#ifndef BOTAN_DILITHIUM_POLYNOMIAL_H_ +#define BOTAN_DILITHIUM_POLYNOMIAL_H_ + +#include +#include +#include +#include + +namespace Botan { + +class DilithiumPolyTraits final : public CRYSTALS::Trait_Base { + private: + friend class CRYSTALS::Trait_Base; + + /** + * NIST FIPS 204 IPD, Algorithm 37 (Montgomery_Reduce) + */ + static constexpr T montgomery_reduce_coefficient(T2 a) { + const T2 t = static_cast(static_cast(static_cast(a)) * Q_inverse); + return (a - static_cast(t) * Q) >> (sizeof(T) * 8); + } + + static constexpr T barrett_reduce_coefficient(T a) { + // 2**22 is roughly Q/2 and 2**23 is roughly Q + T t = (a + (1 << 22)) >> 23; + a = a - t * Q; + return a; + } + + public: + /** + * NIST FIPS 204 IPD, Algorithm 35 (NTT) + * + * Note: ntt(), inverse_ntt() and operator* have side effects on the + * montgomery factor of the involved coefficients! + * It is assumed that EXACTLY ONE vector or matrix multiplication + * is performed between transforming in and out of NTT domain. + * + * Produces the result of the NTT transformation without any montgomery + * factors in the coefficients. + */ + static constexpr void ntt(std::span coeffs) { + size_t j; + size_t k = 0; + + for(size_t len = N / 2; len > 0; len >>= 1) { + for(size_t start = 0; start < N; start = j + len) { + const T zeta = zetas[++k]; + for(j = start; j < start + len; ++j) { + // Zetas contain the montgomery parameter 2^32 mod q + T t = fqmul(zeta, coeffs[j + len]); + coeffs[j + len] = coeffs[j] - t; + coeffs[j] = coeffs[j] + t; + } + } + } + } + + /** + * NIST FIPS 204 IPD, Algorithm 36 (NTT^-1). + * + * The output is effectively multiplied by the montgomery parameter 2^32 + * mod q so that the input factors 2^(-32) mod q are eliminated. Note + * that factors 2^(-32) mod q are introduced by multiplication and + * reduction of values not in montgomery domain. + * + * Produces the result of the inverse NTT transformation with a montgomery + * factor of (2^32 mod q) added (!). See above. + */ + static constexpr void inverse_ntt(std::span coeffs) { + size_t j; + size_t k = N; + for(size_t len = 1; len < N; len <<= 1) { + for(size_t start = 0; start < N; start = j + len) { + const T zeta = -zetas[--k]; + for(j = start; j < start + len; ++j) { + T t = coeffs[j]; + coeffs[j] = t + coeffs[j + len]; + coeffs[j + len] = t - coeffs[j + len]; + // Zetas contain the montgomery parameter 2^32 mod q + coeffs[j + len] = fqmul(zeta, coeffs[j + len]); + } + } + } + + for(auto& coeff : coeffs) { + coeff = fqmul(coeff, F_WITH_MONTY_SQUARED); + } + } + + /** + * Multiplication of two polynomials @p lhs and @p rhs in NTT domain. + * + * Produces the result of the multiplication in NTT domain, with a factor + * of (2^-32 mod q) in each element due to montgomery reduction. + */ + static constexpr void poly_pointwise_montgomery(std::span result, + std::span lhs, + std::span rhs) { + for(size_t i = 0; i < N; ++i) { + result[i] = fqmul(lhs[i], rhs[i]); + } + } +}; + +} // namespace Botan + +#endif diff --git a/src/lib/pubkey/dilithium/dilithium_common/dilithium_polynomials.h b/src/lib/pubkey/dilithium/dilithium_common/dilithium_polynomials.h deleted file mode 100644 index 569c67d8d92..00000000000 --- a/src/lib/pubkey/dilithium/dilithium_common/dilithium_polynomials.h +++ /dev/null @@ -1,1437 +0,0 @@ -/* -* Crystals Dilithium Digital Signature Algorithms -* Based on the public domain reference implementation by the -* designers (https://github.com/pq-crystals/dilithium) -* -* Further changes -* (C) 2021-2023 Jack Lloyd -* (C) 2021-2022 Manuel Glaser - Rohde & Schwarz Cybersecurity -* (C) 2021-2023 Michael Boric, René Meusel - Rohde & Schwarz Cybersecurity -* -* Botan is released under the Simplified BSD License (see license.txt) -*/ - -#ifndef BOTAN_DILITHIUM_POLYNOMIAL_H_ -#define BOTAN_DILITHIUM_POLYNOMIAL_H_ - -#include - -#include -#include - -#include -#include -#include - -namespace Botan::Dilithium { - -class Polynomial { - public: - // public member is on purpose - std::array m_coeffs; - - /** - * Adds two polynomials element-wise. Does not perform a reduction after the addition. - * Therefore this operation might cause an integer overflow. - */ - Polynomial& operator+=(const Polynomial& other) { - for(size_t i = 0; i < this->m_coeffs.size(); ++i) { - this->m_coeffs[i] = this->m_coeffs[i] + other.m_coeffs[i]; - } - return *this; - } - - /** - * Subtracts two polynomials element-wise. Does not perform a reduction after the subtraction. - * Therefore this operation might cause an integer underflow. - */ - Polynomial& operator-=(const Polynomial& other) { - for(size_t i = 0; i < this->m_coeffs.size(); ++i) { - this->m_coeffs[i] = this->m_coeffs[i] - other.m_coeffs[i]; - } - return *this; - } - - /*************************************************** - * Name: rej_uniform - * - * Description: Sample uniformly random coefficients in [0, Q-1] by - * performing rejection sampling on array of random bytes. - * - * Arguments: - Polynomial& a: reference to output array (allocated) - * - size_t position: starting point - * - size_t len: number of coefficients to be sampled - * - const uint8_t *buf: array of random bytes - * - size_t buflen: length of array of random bytes - * - * Returns number of sampled coefficients. Can be smaller than len if not enough - * random bytes were given. - **************************************************/ - static size_t rej_uniform(Polynomial& p, size_t position, size_t len, const uint8_t* buf, size_t buflen) { - size_t ctr = 0, pos = 0; - while(ctr < len && pos + 3 <= buflen) { - uint32_t t = buf[pos++]; - t |= static_cast(buf[pos++]) << 8; - t |= static_cast(buf[pos++]) << 16; - t &= 0x7FFFFF; - - if(t < DilithiumModeConstants::Q) { - p.m_coeffs[position + ctr++] = static_cast(t); - } - } - return ctr; - } - - /************************************************* - * Name: rej_eta - * - * Description: Sample uniformly random coefficients in [-ETA, ETA] by - * performing rejection sampling on array of random bytes. - * - * Arguments: - Polynomial &a: pointer to output array (allocated) - * - size_t offset: starting point for the output polynomial - * - size_t len: number of coefficients to be sampled - * - const secure_vector& buf: sv reference of random bytes - * - size_t buflen: length of array of random bytes - * - const DilithiumModeConstants& - * - * Returns number of sampled coefficients. Can be smaller than len if not enough - * random bytes were given. - **************************************************/ - static size_t rej_eta(Polynomial& a, - size_t offset, - size_t len, - const secure_vector& buf, - size_t buflen, - const DilithiumModeConstants& mode) { - size_t ctr = 0, pos = 0; - while(ctr < len && pos < buflen) { - uint32_t t0 = buf[pos] & 0x0F; - uint32_t t1 = buf[pos++] >> 4; - - switch(mode.eta()) { - case DilithiumEta::Eta2: { - if(t0 < 15) { - t0 = t0 - (205 * t0 >> 10) * 5; - a.m_coeffs[offset + ctr++] = 2 - t0; - } - if(t1 < 15 && ctr < len) { - t1 = t1 - (205 * t1 >> 10) * 5; - a.m_coeffs[offset + ctr++] = 2 - t1; - } - } break; - case DilithiumEta::Eta4: { - if(t0 < 9) { - a.m_coeffs[offset + ctr++] = 4 - t0; - } - if(t1 < 9 && ctr < len) { - a.m_coeffs[offset + ctr++] = 4 - t1; - } - } break; - } - } - return ctr; - } - - /************************************************* - * Name: fill_poly_uniform_eta - * - * Description: Sample polynomial with uniformly random coefficients - * in [-ETA,ETA] by performing rejection sampling on the - * output stream from SHAKE256(seed|nonce) or AES256CTR(seed,nonce). - * - * Arguments: - Polynomial& a: reference to output polynomial - * - const uint8_t seed[]: byte array with seed of length CRHBYTES - * - uint16_t nonce: 2-byte nonce - * - const DilithiumModeConstants& mode: Mode dependent values. - **************************************************/ - static void fill_poly_uniform_eta(Polynomial& a, - const secure_vector& seed, - uint16_t nonce, - const DilithiumModeConstants& mode) { - BOTAN_ASSERT_NOMSG(seed.size() == DilithiumModeConstants::CRHBYTES); - - auto xof = mode.XOF_256(seed, nonce); - - secure_vector buf(mode.poly_uniform_eta_nblocks() * mode.stream256_blockbytes()); - xof->output(buf); - size_t ctr = Polynomial::rej_eta(a, 0, DilithiumModeConstants::N, buf, buf.size(), mode); - - while(ctr < DilithiumModeConstants::N) { - xof->output(std::span(buf).first(mode.stream256_blockbytes())); - ctr += Polynomial::rej_eta(a, ctr, DilithiumModeConstants::N - ctr, buf, mode.stream256_blockbytes(), mode); - } - } - - /************************************************* - * Name: power2round - * - * Description: For finite field element a, compute a0, a1 such that - * a mod^+ Q = a1*2^D + a0 with -2^{D-1} < a0 <= 2^{D-1}. - * Assumes a to be standard representative. - * - * Arguments: - int32_t a: input element - * - int32_t *a0: pointer to output element a0 - * - * Returns a1. - **************************************************/ - static int32_t power2round(int32_t& a0, int32_t a) { - int32_t a1 = (a + (1 << (DilithiumModeConstants::D - 1)) - 1) >> DilithiumModeConstants::D; - a0 = a - (a1 << DilithiumModeConstants::D); - return a1; - } - - /************************************************* - * Name: fill_polys_power2round - * - * Description: For all coefficients c of the input polynomial, - * compute c0, c1 such that c mod Q = c1*2^D + c0 - * with -2^{D-1} < c0 <= 2^{D-1}. Assumes coefficients to be - * standard representatives. - * - * Arguments: - Polynomial& a1: pointer to output polynomial with coefficients c1 - * - Polynomial& a0: pointer to output polynomial with coefficients c0 - * - const Polynomial& a: pointer to input polynomial - **************************************************/ - static void fill_polys_power2round(Polynomial& a1, Polynomial& a0, const Polynomial& a) { - for(size_t i = 0; i < DilithiumModeConstants::N; ++i) { - a1.m_coeffs[i] = Polynomial::power2round(a0.m_coeffs[i], a.m_coeffs[i]); - } - } - - /************************************************* - * Name: challenge - * - * Description: Implementation of H. Samples polynomial with TAU nonzero - * coefficients in {-1,1} using the output stream of - * SHAKE256(seed). - * - * Arguments: - Polynomial &c: pointer to output polynomial - * - const uint8_t mu[]: byte array containing seed of length SEEDBYTES - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - **************************************************/ - static Polynomial poly_challenge(const uint8_t* seed, const DilithiumModeConstants& mode) { - Polynomial c; - - SHAKE_256 shake256_hasher(DilithiumModeConstants::SHAKE256_RATE * 8); - shake256_hasher.update(seed, DilithiumModeConstants::SEEDBYTES); - auto buf = shake256_hasher.final(); - - uint64_t signs = 0; - for(size_t i = 0; i < 8; ++i) { - signs |= static_cast(buf[i]) << 8 * i; - } - size_t pos = 8; - - for(size_t i = 0; i < DilithiumModeConstants::N; ++i) { - c.m_coeffs[i] = 0; - } - for(size_t i = DilithiumModeConstants::N - mode.tau(); i < DilithiumModeConstants::N; ++i) { - size_t b; - do { - b = buf[pos++]; - } while(b > i); - - c.m_coeffs[i] = c.m_coeffs[b]; - c.m_coeffs[b] = 1 - 2 * (signs & 1); - signs >>= 1; - } - return c; - } - - /************************************************* - * Name: poly_chknorm - * - * Description: Check infinity norm of polynomial against given bound. - * Assumes input coefficients were reduced by reduce32(). - * - * Arguments: - const Polynomial& a: pointer to polynomial - * - size_t B: norm bound - * - * Returns false if norm is strictly smaller than B <= (Q-1)/8 and true otherwise. - **************************************************/ - static bool poly_chknorm(const Polynomial& a, size_t B) { - if(B > (DilithiumModeConstants::Q - 1) / 8) { - return true; - } - - /* It is ok to leak which coefficient violates the bound since - the probability for each coefficient is independent of secret - data but we must not leak the sign of the centralized representative. */ - for(const auto& coeff : a.m_coeffs) { - /* Absolute value */ - size_t t = coeff >> 31; - t = coeff - (t & 2 * coeff); - - if(t >= B) { - return true; - } - } - return false; - } - - /************************************************* - * Name: make_hint - * - * Description: Compute hint bit indicating whether the low bits of the - * input element overflow into the high bits. Inputs assumed - * to be standard representatives. - * - * Arguments: - size_t a0: low bits of input element - * - size_t a1: high bits of input element - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - * - * Returns 1 if overflow. - **************************************************/ - static int32_t make_hint(size_t a0, size_t a1, const DilithiumModeConstants& mode) { - const auto gamma2 = mode.gamma2(); - const auto Q_gamma2 = DilithiumModeConstants::Q - gamma2; - if(a0 <= gamma2 || a0 > Q_gamma2 || (a0 == Q_gamma2 && a1 == 0)) { - return 0; - } - return 1; - } - - /************************************************* - * Name: generate_hint_polynomial - * - * Description: Compute hint polynomial. The coefficients of which indicate - * whether the low bits of the corresponding coefficient of - * the input polynomial overflow into the high bits. - * - * Arguments: - Polynomial& h: reference to output hint polynomial - * - const Polynomial& a0: reference to low part of input polynomial - * - const Polynomial& a1: reference to high part of input polynomial - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - * - * Returns number of 1 bits. - **************************************************/ - static size_t generate_hint_polynomial(Polynomial& h, - const Polynomial& a0, - const Polynomial& a1, - const DilithiumModeConstants& mode) { - size_t s = 0; - - for(size_t i = 0; i < DilithiumModeConstants::N; ++i) { - h.m_coeffs[i] = Polynomial::make_hint(a0.m_coeffs[i], a1.m_coeffs[i], mode); - s += h.m_coeffs[i]; - } - - return s; - } - - /************************************************* - * Name: decompose - * - * Description: For finite field element a, compute high and low bits a0, a1 such - * that a mod^+ Q = a1*ALPHA + a0 with -ALPHA/2 < a0 <= ALPHA/2 except - * if a1 = (Q-1)/ALPHA where we set a1 = 0 and - * -ALPHA/2 <= a0 = a mod^+ Q - Q < 0. Assumes a to be standard - * representative. - * - * Arguments: - int32_t a: input element - * - int32_t *a0: pointer to output element a0 - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - * - * Returns a1. - **************************************************/ - static int32_t decompose(int32_t* a0, int32_t a, const DilithiumModeConstants& mode) { - int32_t a1 = (a + 127) >> 7; - if(mode.gamma2() == (DilithiumModeConstants::Q - 1) / 32) { - a1 = (a1 * 1025 + (1 << 21)) >> 22; - a1 &= 15; - } else { - BOTAN_ASSERT_NOMSG(mode.gamma2() == (DilithiumModeConstants::Q - 1) / 88); - a1 = (a1 * 11275 + (1 << 23)) >> 24; - a1 ^= ((43 - a1) >> 31) & a1; - } - - *a0 = a - a1 * 2 * static_cast(mode.gamma2()); - *a0 -= (((DilithiumModeConstants::Q - 1) / 2 - *a0) >> 31) & DilithiumModeConstants::Q; - return a1; - } - - /************************************************* - * Name: use_hint - * - * Description: Correct high bits according to hint. - * - * Arguments: - int32_t a: input element - * - size_t hint: hint bit - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - * - * Returns corrected high bits. - **************************************************/ - static int32_t use_hint(int32_t a, size_t hint, const DilithiumModeConstants& mode) { - int32_t a0; - - int32_t a1 = Polynomial::decompose(&a0, a, mode); - if(hint == 0) { - return a1; - } - - if(mode.gamma2() == ((DilithiumModeConstants::Q - 1) / 32)) { - if(a0 > 0) { - return (a1 + 1) & 15; - } else { - return (a1 - 1) & 15; - } - } else { - if(a0 > 0) { - return (a1 == 43) ? 0 : a1 + 1; - } else { - return (a1 == 0) ? 43 : a1 - 1; - } - } - } - - /************************************************* - * Name: poly_use_hint - * - * Description: Use hint polynomial to correct the high bits of a polynomial. - * - * Arguments: - Polynomial& b: reference to output polynomial with corrected high bits - * - const Polynomial& a: reference to input polynomial - * - const Polynomial& h: reference to input hint polynomial - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - * - **************************************************/ - static void poly_use_hint(Polynomial& b, - const Polynomial& a, - const Polynomial& h, - const DilithiumModeConstants& mode) { - for(size_t i = 0; i < DilithiumModeConstants::N; ++i) { - b.m_coeffs[i] = Polynomial::use_hint(a.m_coeffs[i], h.m_coeffs[i], mode); - } - } - - /************************************************* - * Name: montgomery_reduce - * - * Description: For finite field element a with -2^{31}Q <= a <= Q*2^31, - * compute r \equiv a*2^{-32} (mod Q) such that -Q < r < Q. - * - * Arguments: - int64_t: finite field element a - * - * Returns r. - **************************************************/ - int32_t montgomery_reduce(int64_t a) const { - int32_t t = static_cast(static_cast(static_cast(a)) * DilithiumModeConstants::QINV); - t = (a - static_cast(t) * DilithiumModeConstants::Q) >> 32; - return t; - } - - /************************************************* - * Name: poly_pointwise_montgomery - * - * Description: Pointwise multiplication of polynomials in NTT domain - * representation and multiplication of resulting polynomial - * by 2^{-32}. - * For finite field element a with -2^{31}Q <= a <= Q*2^31, - * compute r \equiv a*2^{-32} (mod Q) such that -Q < r < Q. - * - * Arguments: - Polynomial& c: reference to output polynomial - * - const Polynomial& a: reference to first input polynomial - * - const Polynomial& b: reference to second input polynomial - **************************************************/ - void poly_pointwise_montgomery(Polynomial& output, const Polynomial& second) const { - for(size_t i = 0; i < DilithiumModeConstants::N; ++i) { - output.m_coeffs[i] = montgomery_reduce(static_cast(m_coeffs[i]) * second.m_coeffs[i]); - } - } - - /************************************************* - * Name: ntt - * - * Description: Forward NTT, in-place. No modular reduction is performed after - * additions or subtractions. Output vector is in bitreversed order. - * - * Arguments: - Polynomial& a: input/output coefficient Polynomial - **************************************************/ - void ntt() { - size_t j; - size_t k = 0; - - for(size_t len = 128; len > 0; len >>= 1) { - for(size_t start = 0; start < DilithiumModeConstants::N; start = j + len) { - int32_t zeta = DilithiumModeConstants::ZETAS[++k]; - for(j = start; j < start + len; ++j) { - int32_t t = montgomery_reduce(static_cast(zeta) * m_coeffs[j + len]); - m_coeffs[j + len] = m_coeffs[j] - t; - m_coeffs[j] = m_coeffs[j] + t; - } - } - } - } - - /************************************************* - * Name: poly_reduce - * - * Description: Inplace reduction of all coefficients of polynomial to - * representative in [-6283009,6283007]. - * For finite field element a with a <= 2^{31} - 2^{22} - 1, - * compute r \equiv a (mod Q) such that -6283009 <= r <= 6283007. - * - * Arguments: - Polynomial &a: reference to input polynomial - **************************************************/ - void poly_reduce() { - for(auto& i : m_coeffs) { - int32_t t = (i + (1 << 22)) >> 23; - t = i - t * DilithiumModeConstants::Q; - i = t; - } - } - - /************************************************* - * Name: invntt_tomont - * - * Description: Inverse NTT and multiplication by Montgomery factor 2^32. - * In-place. No modular reductions after additions or - * subtractions; input coefficients need to be smaller than - * Q in absolute value. Output coefficient are smaller than Q in - * absolute value. - **************************************************/ - void invntt_tomont() { - size_t j; - int32_t f = 41978; // mont^2/256 - size_t k = 256; - for(size_t len = 1; len < DilithiumModeConstants::N; len <<= 1) { - for(size_t start = 0; start < DilithiumModeConstants::N; start = j + len) { - int32_t zeta = -DilithiumModeConstants::ZETAS[--k]; - for(j = start; j < start + len; ++j) { - int32_t t = m_coeffs[j]; - m_coeffs[j] = t + m_coeffs[j + len]; - m_coeffs[j + len] = t - m_coeffs[j + len]; - m_coeffs[j + len] = montgomery_reduce(static_cast(zeta) * m_coeffs[j + len]); - } - } - } - - for(j = 0; j < DilithiumModeConstants::N; ++j) { - m_coeffs[j] = montgomery_reduce(static_cast(f) * m_coeffs[j]); - } - } - - /************************************************* - * Name: poly_invntt_tomont - * - * Description: Inplace inverse NTT and multiplication by 2^{32}. - * Input coefficients need to be less than Q in absolute - * value and output coefficients are again bounded by Q. - * - * Arguments: - Polynomial& a: reference to input/output polynomial - **************************************************/ - void poly_invntt_tomont() { invntt_tomont(); } - - /************************************************* - * Name: cadd_q - * - * Description: For all coefficients of in/out polynomial add Q if - * coefficient is negative. - * Add Q if input coefficient is negative. - **************************************************/ - void cadd_q() { - for(auto& i : m_coeffs) { - i += (i >> 31) & DilithiumModeConstants::Q; - } - } - - /************************************************* - * Name: poly_uniform_gamma1 - * - * Description: Sample polynomial with uniformly random coefficients - * in [-(GAMMA1 - 1), GAMMA1] by unpacking output stream - * of SHAKE256(seed|nonce) or AES256CTR(seed,nonce). - * - * Arguments: - const secure_vector& seed: vector with seed of length CRHBYTES - * - uint16_t nonce: 16-bit nonce - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - **************************************************/ - void poly_uniform_gamma1(const secure_vector& seed, uint16_t nonce, const DilithiumModeConstants& mode) { - auto buf = mode.ExpandMask(seed, nonce); - - Polynomial::polyz_unpack(*this, buf.data(), mode); - } - - /************************************************* - * Name: poly_decompose - * - * Description: For all coefficients c of the input polynomial, - * compute high and low bits c0, c1 such c mod Q = c1*ALPHA + c0 - * with -ALPHA/2 < c0 <= ALPHA/2 except c1 = (Q-1)/ALPHA where we - * set c1 = 0 and -ALPHA/2 <= c0 = c mod Q - Q < 0. - * Assumes coefficients to be standard representatives. - * - * Arguments: - Polynomial& a1: reference to output polynomial with coefficients c1 - * - Polynomial& a0: reference to output polynomial with coefficients c0 - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - **************************************************/ - void poly_decompose(Polynomial& a1, Polynomial& a0, const DilithiumModeConstants& mode) const { - for(size_t i = 0; i < DilithiumModeConstants::N; ++i) { - a1.m_coeffs[i] = Polynomial::decompose(&a0.m_coeffs[i], m_coeffs[i], mode); - } - } - - /************************************************* - * Name: poly_shiftl - * - * Description: Multiply polynomial by 2^D without modular reduction. Assumes - * input coefficients to be less than 2^{31-D} in absolute value. - * - * Arguments: - Polynomial& a: pointer to input/output polynomial - **************************************************/ - void poly_shiftl() { - for(size_t i = 0; i < m_coeffs.size(); ++i) { - m_coeffs[i] <<= DilithiumModeConstants::D; - } - } - - /************************************************* - * Name: polyw1_pack - * - * Description: Bit-pack polynomial w1 with coefficients in [0,15] or [0,43]. - * Input coefficients are assumed to be standard representatives. - * - * Arguments: - uint8_t *r: pointer to output byte array with at least - * POLYW1_PACKEDBYTES bytes - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - **************************************************/ - void polyw1_pack(uint8_t* r, const DilithiumModeConstants& mode) { - if(mode.gamma2() == (DilithiumModeConstants::Q - 1) / 88) { - for(size_t i = 0; i < DilithiumModeConstants::N / 4; ++i) { - r[3 * i + 0] = static_cast(m_coeffs[4 * i + 0]); - r[3 * i + 0] |= static_cast(m_coeffs[4 * i + 1] << 6); - r[3 * i + 1] = static_cast(m_coeffs[4 * i + 1] >> 2); - r[3 * i + 1] |= static_cast(m_coeffs[4 * i + 2] << 4); - r[3 * i + 2] = static_cast(m_coeffs[4 * i + 2] >> 4); - r[3 * i + 2] |= static_cast(m_coeffs[4 * i + 3] << 2); - } - } else { - BOTAN_ASSERT_NOMSG(mode.gamma2() == (DilithiumModeConstants::Q - 1) / 32); - for(size_t i = 0; i < DilithiumModeConstants::N / 2; ++i) { - r[i] = static_cast(m_coeffs[2 * i + 0] | (m_coeffs[2 * i + 1] << 4)); - } - } - } - - /************************************************* - * Name: polyeta_unpack - * - * Description: Unpack polynomial with coefficients in [-ETA,ETA]. - * - * Arguments: - Polynomial& r: reference to output polynomial - * - const uint8_t *a: byte array with bit-packed_t1 polynomial - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - **************************************************/ - static Polynomial polyeta_unpack(std::span a, const DilithiumModeConstants& mode) { - Polynomial r; - - switch(mode.eta()) { - case DilithiumEta::Eta2: { - for(size_t i = 0; i < DilithiumModeConstants::N / 8; ++i) { - r.m_coeffs[8 * i + 0] = (a[3 * i + 0] >> 0) & 7; - r.m_coeffs[8 * i + 1] = (a[3 * i + 0] >> 3) & 7; - r.m_coeffs[8 * i + 2] = ((a[3 * i + 0] >> 6) | (a[3 * i + 1] << 2)) & 7; - r.m_coeffs[8 * i + 3] = (a[3 * i + 1] >> 1) & 7; - r.m_coeffs[8 * i + 4] = (a[3 * i + 1] >> 4) & 7; - r.m_coeffs[8 * i + 5] = ((a[3 * i + 1] >> 7) | (a[3 * i + 2] << 1)) & 7; - r.m_coeffs[8 * i + 6] = (a[3 * i + 2] >> 2) & 7; - r.m_coeffs[8 * i + 7] = (a[3 * i + 2] >> 5) & 7; - - r.m_coeffs[8 * i + 0] = static_cast(mode.eta()) - r.m_coeffs[8 * i + 0]; - r.m_coeffs[8 * i + 1] = static_cast(mode.eta()) - r.m_coeffs[8 * i + 1]; - r.m_coeffs[8 * i + 2] = static_cast(mode.eta()) - r.m_coeffs[8 * i + 2]; - r.m_coeffs[8 * i + 3] = static_cast(mode.eta()) - r.m_coeffs[8 * i + 3]; - r.m_coeffs[8 * i + 4] = static_cast(mode.eta()) - r.m_coeffs[8 * i + 4]; - r.m_coeffs[8 * i + 5] = static_cast(mode.eta()) - r.m_coeffs[8 * i + 5]; - r.m_coeffs[8 * i + 6] = static_cast(mode.eta()) - r.m_coeffs[8 * i + 6]; - r.m_coeffs[8 * i + 7] = static_cast(mode.eta()) - r.m_coeffs[8 * i + 7]; - } - } break; - case DilithiumEta::Eta4: { - for(size_t i = 0; i < DilithiumModeConstants::N / 2; ++i) { - r.m_coeffs[2 * i + 0] = a[i] & 0x0F; - r.m_coeffs[2 * i + 1] = a[i] >> 4; - r.m_coeffs[2 * i + 0] = static_cast(mode.eta()) - r.m_coeffs[2 * i + 0]; - r.m_coeffs[2 * i + 1] = static_cast(mode.eta()) - r.m_coeffs[2 * i + 1]; - } - } break; - } - - return r; - } - - /************************************************* - * Name: polyeta_pack - * - * Description: Bit-pack polynomial with coefficients in [-ETA,ETA]. - * - * Arguments: - uint8_t *r: pointer to output byte array with at least - * POLYETA_PACKEDBYTES bytes - * - const Polynomial& a: pointer to input polynomial - * - const DilithiumModeConstants& mode: reference for dilithium mode values - **************************************************/ - void polyeta_pack(uint8_t* r, const DilithiumModeConstants& mode) const { - uint8_t t[8]; - - switch(mode.eta()) { - case DilithiumEta::Eta2: { - for(size_t i = 0; i < DilithiumModeConstants::N / 8; ++i) { - t[0] = static_cast(mode.eta() - m_coeffs[8 * i + 0]); - t[1] = static_cast(mode.eta() - m_coeffs[8 * i + 1]); - t[2] = static_cast(mode.eta() - m_coeffs[8 * i + 2]); - t[3] = static_cast(mode.eta() - m_coeffs[8 * i + 3]); - t[4] = static_cast(mode.eta() - m_coeffs[8 * i + 4]); - t[5] = static_cast(mode.eta() - m_coeffs[8 * i + 5]); - t[6] = static_cast(mode.eta() - m_coeffs[8 * i + 6]); - t[7] = static_cast(mode.eta() - m_coeffs[8 * i + 7]); - - r[3 * i + 0] = (t[0] >> 0) | (t[1] << 3) | (t[2] << 6); - r[3 * i + 1] = (t[2] >> 2) | (t[3] << 1) | (t[4] << 4) | (t[5] << 7); - r[3 * i + 2] = (t[5] >> 1) | (t[6] << 2) | (t[7] << 5); - } - } break; - case DilithiumEta::Eta4: { - for(size_t i = 0; i < DilithiumModeConstants::N / 2; ++i) { - t[0] = static_cast(mode.eta() - m_coeffs[2 * i + 0]); - t[1] = static_cast(mode.eta() - m_coeffs[2 * i + 1]); - r[i] = static_cast(t[0] | (t[1] << 4)); - } - } break; - } - } - - /************************************************* - * Name: polyt0_unpack - * - * Description: Unpack polynomial t0 with coefficients in ]-2^{D-1}, 2^{D-1}]. - * - * Arguments: - poly *r: pointer to output polynomial - * - const uint8_t *a: byte array with bit-packed_t1 polynomial - **************************************************/ - static Polynomial polyt0_unpack(std::span a) { - Polynomial r; - - for(size_t i = 0; i < DilithiumModeConstants::N / 8; ++i) { - r.m_coeffs[8 * i + 0] = a[13 * i + 0]; - r.m_coeffs[8 * i + 0] |= static_cast(a[13 * i + 1]) << 8; - r.m_coeffs[8 * i + 0] &= 0x1FFF; - - r.m_coeffs[8 * i + 1] = a[13 * i + 1] >> 5; - r.m_coeffs[8 * i + 1] |= static_cast(a[13 * i + 2]) << 3; - r.m_coeffs[8 * i + 1] |= static_cast(a[13 * i + 3]) << 11; - r.m_coeffs[8 * i + 1] &= 0x1FFF; - - r.m_coeffs[8 * i + 2] = a[13 * i + 3] >> 2; - r.m_coeffs[8 * i + 2] |= static_cast(a[13 * i + 4]) << 6; - r.m_coeffs[8 * i + 2] &= 0x1FFF; - - r.m_coeffs[8 * i + 3] = a[13 * i + 4] >> 7; - r.m_coeffs[8 * i + 3] |= static_cast(a[13 * i + 5]) << 1; - r.m_coeffs[8 * i + 3] |= static_cast(a[13 * i + 6]) << 9; - r.m_coeffs[8 * i + 3] &= 0x1FFF; - - r.m_coeffs[8 * i + 4] = a[13 * i + 6] >> 4; - r.m_coeffs[8 * i + 4] |= static_cast(a[13 * i + 7]) << 4; - r.m_coeffs[8 * i + 4] |= static_cast(a[13 * i + 8]) << 12; - r.m_coeffs[8 * i + 4] &= 0x1FFF; - - r.m_coeffs[8 * i + 5] = a[13 * i + 8] >> 1; - r.m_coeffs[8 * i + 5] |= static_cast(a[13 * i + 9]) << 7; - r.m_coeffs[8 * i + 5] &= 0x1FFF; - - r.m_coeffs[8 * i + 6] = a[13 * i + 9] >> 6; - r.m_coeffs[8 * i + 6] |= static_cast(a[13 * i + 10]) << 2; - r.m_coeffs[8 * i + 6] |= static_cast(a[13 * i + 11]) << 10; - r.m_coeffs[8 * i + 6] &= 0x1FFF; - - r.m_coeffs[8 * i + 7] = a[13 * i + 11] >> 3; - r.m_coeffs[8 * i + 7] |= static_cast(a[13 * i + 12]) << 5; - r.m_coeffs[8 * i + 7] &= 0x1FFF; - - r.m_coeffs[8 * i + 0] = (1 << (DilithiumModeConstants::D - 1)) - r.m_coeffs[8 * i + 0]; - r.m_coeffs[8 * i + 1] = (1 << (DilithiumModeConstants::D - 1)) - r.m_coeffs[8 * i + 1]; - r.m_coeffs[8 * i + 2] = (1 << (DilithiumModeConstants::D - 1)) - r.m_coeffs[8 * i + 2]; - r.m_coeffs[8 * i + 3] = (1 << (DilithiumModeConstants::D - 1)) - r.m_coeffs[8 * i + 3]; - r.m_coeffs[8 * i + 4] = (1 << (DilithiumModeConstants::D - 1)) - r.m_coeffs[8 * i + 4]; - r.m_coeffs[8 * i + 5] = (1 << (DilithiumModeConstants::D - 1)) - r.m_coeffs[8 * i + 5]; - r.m_coeffs[8 * i + 6] = (1 << (DilithiumModeConstants::D - 1)) - r.m_coeffs[8 * i + 6]; - r.m_coeffs[8 * i + 7] = (1 << (DilithiumModeConstants::D - 1)) - r.m_coeffs[8 * i + 7]; - } - - return r; - } - - /************************************************* - * Name: polyt0_pack - * - * Description: Bit-pack polynomial t0 with coefficients in ]-2^{D-1}, 2^{D-1}]. - * - * Arguments: - uint8_t *r: pointer to output byte array with at least - * POLYT0_PACKEDBYTES bytes - * - const Polynomial& a: reference to input polynomial - **************************************************/ - void polyt0_pack(uint8_t* r) const { - uint32_t t[8]; - for(size_t i = 0; i < DilithiumModeConstants::N / 8; ++i) { - t[0] = (1 << (DilithiumModeConstants::D - 1)) - m_coeffs[8 * i + 0]; - t[1] = (1 << (DilithiumModeConstants::D - 1)) - m_coeffs[8 * i + 1]; - t[2] = (1 << (DilithiumModeConstants::D - 1)) - m_coeffs[8 * i + 2]; - t[3] = (1 << (DilithiumModeConstants::D - 1)) - m_coeffs[8 * i + 3]; - t[4] = (1 << (DilithiumModeConstants::D - 1)) - m_coeffs[8 * i + 4]; - t[5] = (1 << (DilithiumModeConstants::D - 1)) - m_coeffs[8 * i + 5]; - t[6] = (1 << (DilithiumModeConstants::D - 1)) - m_coeffs[8 * i + 6]; - t[7] = (1 << (DilithiumModeConstants::D - 1)) - m_coeffs[8 * i + 7]; - - r[13 * i + 0] = static_cast(t[0]); - r[13 * i + 1] = static_cast(t[0] >> 8); - r[13 * i + 1] |= static_cast(t[1] << 5); - r[13 * i + 2] = static_cast(t[1] >> 3); - r[13 * i + 3] = static_cast(t[1] >> 11); - r[13 * i + 3] |= static_cast(t[2] << 2); - r[13 * i + 4] = static_cast(t[2] >> 6); - r[13 * i + 4] |= static_cast(t[3] << 7); - r[13 * i + 5] = static_cast(t[3] >> 1); - r[13 * i + 6] = static_cast(t[3] >> 9); - r[13 * i + 6] |= static_cast(t[4] << 4); - r[13 * i + 7] = static_cast(t[4] >> 4); - r[13 * i + 8] = static_cast(t[4] >> 12); - r[13 * i + 8] |= static_cast(t[5] << 1); - r[13 * i + 9] = static_cast(t[5] >> 7); - r[13 * i + 9] |= static_cast(t[6] << 6); - r[13 * i + 10] = static_cast(t[6] >> 2); - r[13 * i + 11] = static_cast(t[6] >> 10); - r[13 * i + 11] |= static_cast(t[7] << 3); - r[13 * i + 12] = static_cast(t[7] >> 5); - } - } - - /************************************************* - * Name: polyz_unpack - * - * Description: Unpack polynomial z with coefficients - * in [-(GAMMA1 - 1), GAMMA1]. - * - * Arguments: - Polynomial& r: pointer to output polynomial - * - const uint8_t *a: byte array with bit-packed_t1 polynomial - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - **************************************************/ - static void polyz_unpack(Polynomial& r, const uint8_t* a, const DilithiumModeConstants& mode) { - if(mode.gamma1() == (1 << 17)) { - for(size_t i = 0; i < DilithiumModeConstants::N / 4; ++i) { - r.m_coeffs[4 * i + 0] = a[9 * i + 0]; - r.m_coeffs[4 * i + 0] |= static_cast(a[9 * i + 1]) << 8; - r.m_coeffs[4 * i + 0] |= static_cast(a[9 * i + 2]) << 16; - r.m_coeffs[4 * i + 0] &= 0x3FFFF; - - r.m_coeffs[4 * i + 1] = a[9 * i + 2] >> 2; - r.m_coeffs[4 * i + 1] |= static_cast(a[9 * i + 3]) << 6; - r.m_coeffs[4 * i + 1] |= static_cast(a[9 * i + 4]) << 14; - r.m_coeffs[4 * i + 1] &= 0x3FFFF; - - r.m_coeffs[4 * i + 2] = a[9 * i + 4] >> 4; - r.m_coeffs[4 * i + 2] |= static_cast(a[9 * i + 5]) << 4; - r.m_coeffs[4 * i + 2] |= static_cast(a[9 * i + 6]) << 12; - r.m_coeffs[4 * i + 2] &= 0x3FFFF; - - r.m_coeffs[4 * i + 3] = a[9 * i + 6] >> 6; - r.m_coeffs[4 * i + 3] |= static_cast(a[9 * i + 7]) << 2; - r.m_coeffs[4 * i + 3] |= static_cast(a[9 * i + 8]) << 10; - r.m_coeffs[4 * i + 3] &= 0x3FFFF; - - r.m_coeffs[4 * i + 0] = static_cast(mode.gamma1()) - r.m_coeffs[4 * i + 0]; - r.m_coeffs[4 * i + 1] = static_cast(mode.gamma1()) - r.m_coeffs[4 * i + 1]; - r.m_coeffs[4 * i + 2] = static_cast(mode.gamma1()) - r.m_coeffs[4 * i + 2]; - r.m_coeffs[4 * i + 3] = static_cast(mode.gamma1()) - r.m_coeffs[4 * i + 3]; - } - } else if(mode.gamma1() == (1 << 19)) { - for(size_t i = 0; i < DilithiumModeConstants::N / 2; ++i) { - r.m_coeffs[2 * i + 0] = a[5 * i + 0]; - r.m_coeffs[2 * i + 0] |= static_cast(a[5 * i + 1]) << 8; - r.m_coeffs[2 * i + 0] |= static_cast(a[5 * i + 2]) << 16; - r.m_coeffs[2 * i + 0] &= 0xFFFFF; - - r.m_coeffs[2 * i + 1] = a[5 * i + 2] >> 4; - r.m_coeffs[2 * i + 1] |= static_cast(a[5 * i + 3]) << 4; - r.m_coeffs[2 * i + 1] |= static_cast(a[5 * i + 4]) << 12; - r.m_coeffs[2 * i + 0] &= 0xFFFFF; - - r.m_coeffs[2 * i + 0] = static_cast(mode.gamma1()) - r.m_coeffs[2 * i + 0]; - r.m_coeffs[2 * i + 1] = static_cast(mode.gamma1()) - r.m_coeffs[2 * i + 1]; - } - } - } - - /************************************************* - * Name: polyz_pack - * - * Description: Bit-pack polynomial with coefficients - * in [-(GAMMA1 - 1), GAMMA1]. - * - * Arguments: - uint8_t *r: pointer to output byte array with at least - * POLYZ_PACKEDBYTES bytes - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - **************************************************/ - void polyz_pack(uint8_t* r, const DilithiumModeConstants& mode) const { - uint32_t t[4]; - if(mode.gamma1() == (1 << 17)) { - for(size_t i = 0; i < DilithiumModeConstants::N / 4; ++i) { - t[0] = static_cast(mode.gamma1()) - m_coeffs[4 * i + 0]; - t[1] = static_cast(mode.gamma1()) - m_coeffs[4 * i + 1]; - t[2] = static_cast(mode.gamma1()) - m_coeffs[4 * i + 2]; - t[3] = static_cast(mode.gamma1()) - m_coeffs[4 * i + 3]; - - r[9 * i + 0] = static_cast(t[0]); - r[9 * i + 1] = static_cast(t[0] >> 8); - r[9 * i + 2] = static_cast(t[0] >> 16); - r[9 * i + 2] |= static_cast(t[1] << 2); - r[9 * i + 3] = static_cast(t[1] >> 6); - r[9 * i + 4] = static_cast(t[1] >> 14); - r[9 * i + 4] |= static_cast(t[2] << 4); - r[9 * i + 5] = static_cast(t[2] >> 4); - r[9 * i + 6] = static_cast(t[2] >> 12); - r[9 * i + 6] |= static_cast(t[3] << 6); - r[9 * i + 7] = static_cast(t[3] >> 2); - r[9 * i + 8] = static_cast(t[3] >> 10); - } - } else if(mode.gamma1() == (1 << 19)) { - for(size_t i = 0; i < DilithiumModeConstants::N / 2; ++i) { - t[0] = static_cast(mode.gamma1()) - m_coeffs[2 * i + 0]; - t[1] = static_cast(mode.gamma1()) - m_coeffs[2 * i + 1]; - - r[5 * i + 0] = static_cast(t[0]); - r[5 * i + 1] = static_cast(t[0] >> 8); - r[5 * i + 2] = static_cast(t[0] >> 16); - r[5 * i + 2] |= static_cast(t[1] << 4); - r[5 * i + 3] = static_cast(t[1] >> 4); - r[5 * i + 4] = static_cast(t[1] >> 12); - } - } - } - - /************************************************* - * Name: polyt1_unpack - * - * Description: Unpack polynomial t1 with 10-bit coefficients. - * Output coefficients are standard representatives. - * - * Arguments: - Polynomial& r: pointer to output polynomial - * - const uint8_t *a: byte array with bit-packed_t1 polynomial - **************************************************/ - static void polyt1_unpack(Polynomial& r, const uint8_t* a) { - for(size_t i = 0; i < DilithiumModeConstants::N / 4; ++i) { - r.m_coeffs[4 * i + 0] = ((a[5 * i + 0] >> 0) | (static_cast(a[5 * i + 1]) << 8)) & 0x3FF; - r.m_coeffs[4 * i + 1] = ((a[5 * i + 1] >> 2) | (static_cast(a[5 * i + 2]) << 6)) & 0x3FF; - r.m_coeffs[4 * i + 2] = ((a[5 * i + 2] >> 4) | (static_cast(a[5 * i + 3]) << 4)) & 0x3FF; - r.m_coeffs[4 * i + 3] = ((a[5 * i + 3] >> 6) | (static_cast(a[5 * i + 4]) << 2)) & 0x3FF; - } - } - - /************************************************* - * Name: polyt1_pack - * - * Description: Bit-pack polynomial t1 with coefficients fitting in 10 bits. - * Input coefficients are assumed to be standard representatives. - * - * Arguments: - uint8_t *r: pointer to output byte array with at least - * POLYT1_PACKEDBYTES bytes - **************************************************/ - void polyt1_pack(uint8_t* r) const { - for(size_t i = 0; i < DilithiumModeConstants::N / 4; ++i) { - r[5 * i + 0] = static_cast((m_coeffs[4 * i + 0] >> 0)); - r[5 * i + 1] = static_cast((m_coeffs[4 * i + 0] >> 8) | (m_coeffs[4 * i + 1] << 2)); - r[5 * i + 2] = static_cast((m_coeffs[4 * i + 1] >> 6) | (m_coeffs[4 * i + 2] << 4)); - r[5 * i + 3] = static_cast((m_coeffs[4 * i + 2] >> 4) | (m_coeffs[4 * i + 3] << 6)); - r[5 * i + 4] = static_cast((m_coeffs[4 * i + 3] >> 2)); - } - } - - Polynomial() = default; -}; - -class PolynomialVector { - public: - // public member is on purpose - std::vector m_vec; - - public: - PolynomialVector() = default; - - PolynomialVector& operator+=(const PolynomialVector& other) { - BOTAN_ASSERT_NOMSG(m_vec.size() != other.m_vec.size()); - for(size_t i = 0; i < m_vec.size(); ++i) { - this->m_vec[i] += other.m_vec[i]; - } - return *this; - } - - PolynomialVector& operator-=(const PolynomialVector& other) { - BOTAN_ASSERT_NOMSG(m_vec.size() == other.m_vec.size()); - for(size_t i = 0; i < this->m_vec.size(); ++i) { - this->m_vec[i] -= other.m_vec[i]; - } - return *this; - } - - explicit PolynomialVector(size_t size) : m_vec(size) {} - - /************************************************* - * Name: poly_uniform - * - * Description: Sample polynomial with uniformly random coefficients - * in [0,Q-1] by performing rejection sampling on the - * output stream of SHAKE256(seed|nonce) or AES256CTR(seed,nonce). - * - * Arguments: - const uint8_t seed[]: secure vector with seed of length SEEDBYTES - * - uint16_t nonce: 2-byte nonce - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - * Return Polynomial - **************************************************/ - static Polynomial poly_uniform(const std::vector& seed, - uint16_t nonce, - const DilithiumModeConstants& mode) { - Polynomial sample_poly; - size_t buflen = mode.poly_uniform_nblocks() * mode.stream128_blockbytes(); - - std::vector buf(buflen + 2); - - auto xof = mode.XOF_128(seed, nonce); - xof->output(std::span(buf).first(buflen)); - - size_t ctr = Polynomial::rej_uniform(sample_poly, 0, DilithiumModeConstants::N, buf.data(), buflen); - size_t off; - while(ctr < DilithiumModeConstants::N) { - off = buflen % 3; - for(size_t i = 0; i < off; ++i) { - buf[i] = buf[buflen - off + i]; - } - - xof->output(std::span(buf).subspan(off, mode.stream128_blockbytes())); - buflen = mode.stream128_blockbytes() + off; - ctr += Polynomial::rej_uniform(sample_poly, ctr, DilithiumModeConstants::N - ctr, buf.data(), buflen); - } - return sample_poly; - } - - static void fill_polyvec_uniform_eta(PolynomialVector& v, - const secure_vector& seed, - uint16_t nonce, - const DilithiumModeConstants& mode) { - for(size_t i = 0; i < v.m_vec.size(); ++i) { - Polynomial::fill_poly_uniform_eta(v.m_vec[i], seed, nonce++, mode); - } - } - - /************************************************* - * Name: polyvec_pointwise_acc_montgomery - * - * Description: Pointwise multiply vectors of polynomials of length L, multiply - * resulting vector by 2^{-32} and add (accumulate) polynomials - * in it. Input/output vectors are in NTT domain representation. - * - * Arguments: - Polynomial &w: output polynomial - * - const Polynomial &u: pointer to first input vector - * - const Polynomial &v: pointer to second input vector - **************************************************/ - static void polyvec_pointwise_acc_montgomery(Polynomial& w, - const PolynomialVector& u, - const PolynomialVector& v) { - BOTAN_ASSERT_NOMSG(u.m_vec.size() == v.m_vec.size()); - BOTAN_ASSERT_NOMSG(!u.m_vec.empty() && !v.m_vec.empty()); - - u.m_vec[0].poly_pointwise_montgomery(w, v.m_vec[0]); - - for(size_t i = 1; i < v.m_vec.size(); ++i) { - Polynomial t; - u.m_vec[i].poly_pointwise_montgomery(t, v.m_vec[i]); - w += t; - } - } - - /************************************************* - * Name: fill_polyvecs_power2round - * - * Description: For all coefficients a of polynomials in vector , - * compute a0, a1 such that a mod^+ Q = a1*2^D + a0 - * with -2^{D-1} < a0 <= 2^{D-1}. Assumes coefficients to be - * standard representatives. - * - * Arguments: - PolynomialVector& v1: reference to output vector of polynomials with - * coefficients a1 - * - PolynomialVector& v0: reference to output vector of polynomials with - * coefficients a0 - * - const PolynomialVector& v: reference to input vector - **************************************************/ - static void fill_polyvecs_power2round(PolynomialVector& v1, PolynomialVector& v0, const PolynomialVector& v) { - BOTAN_ASSERT((v1.m_vec.size() == v0.m_vec.size()) && (v1.m_vec.size() == v.m_vec.size()), - "possible buffer overflow! Wrong PolynomialVector sizes."); - for(size_t i = 0; i < v1.m_vec.size(); ++i) { - Polynomial::fill_polys_power2round(v1.m_vec[i], v0.m_vec[i], v.m_vec[i]); - } - } - - static bool unpack_sig(std::array& c, - PolynomialVector& z, - PolynomialVector& h, - const std::vector& sig, - const DilithiumModeConstants& mode) { - //const auto& mode = m_pub_key.m_public->mode(); - BOTAN_ASSERT(sig.size() == mode.crypto_bytes(), "invalid signature size"); - size_t position = 0; - - std::copy(sig.begin(), sig.begin() + c.size(), c.begin()); - - position += DilithiumModeConstants::SEEDBYTES; - - for(size_t i = 0; i < mode.l(); ++i) { - Polynomial::polyz_unpack(z.m_vec[i], sig.data() + position + i * mode.polyz_packedbytes(), mode); - } - position += mode.l() * mode.polyz_packedbytes(); - - /* Decode h */ - size_t k = 0; - for(size_t i = 0; i < mode.k(); ++i) { - for(size_t j = 0; j < DilithiumModeConstants::N; ++j) { - h.m_vec[i].m_coeffs[j] = 0; - } - - if(sig[position + mode.omega() + i] < k || sig[position + mode.omega() + i] > mode.omega()) { - return true; - } - - for(size_t j = k; j < sig[position + mode.omega() + i]; ++j) { - /* Coefficients are ordered for strong unforgeability */ - if(j > k && sig[position + j] <= sig[position + j - 1]) { - return true; - } - h.m_vec[i].m_coeffs[sig[position + j]] = 1; - } - - k = sig[position + mode.omega() + i]; - } - - /* Extra indices are zero for strong unforgeability */ - for(size_t j = k; j < mode.omega(); ++j) { - if(sig[position + j]) { - return true; - } - } - - return false; - } - - /************************************************* - * Name: generate_hint_polyvec - * - * Description: Compute hint vector. - * - * Arguments: - PolynomialVector *h: reference to output vector - * - const PolynomialVector *v0: reference to low part of input vector - * - const PolynomialVector *v1: reference to high part of input vector - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - * - * Returns number of 1 bits. - **************************************************/ - static size_t generate_hint_polyvec(PolynomialVector& h, - const PolynomialVector& v0, - const PolynomialVector& v1, - const DilithiumModeConstants& mode) { - size_t s = 0; - - for(size_t i = 0; i < h.m_vec.size(); ++i) { - s += Polynomial::generate_hint_polynomial(h.m_vec[i], v0.m_vec[i], v1.m_vec[i], mode); - } - - return s; - } - - /************************************************* - * Name: ntt - * - * Description: Forward NTT of all polynomials in vector. Output - * coefficients can be up to 16*Q larger than input coefficients. - **************************************************/ - void ntt() { - for(auto& i : m_vec) { - i.ntt(); - } - } - - /************************************************* - * Name: polyveck_decompose - * - * Description: For all coefficients a of polynomials in vector, - * compute high and low bits a0, a1 such a mod^+ Q = a1*ALPHA + a0 - * with -ALPHA/2 < a0 <= ALPHA/2 except a1 = (Q-1)/ALPHA where we - * set a1 = 0 and -ALPHA/2 <= a0 = a mod Q - Q < 0. - * Assumes coefficients to be standard representatives. - * - * Arguments: - PolynomialVector& v1: reference to output vector of polynomials with - * coefficients a1 - * - PolynomialVector& v0: reference to output vector of polynomials with - * coefficients a0 - * - const PolynomialVector& v: reference to input vector - **************************************************/ - std::tuple polyvec_decompose(const DilithiumModeConstants& mode) { - PolynomialVector v1(mode.k()); - PolynomialVector v0(mode.k()); - - for(size_t i = 0; i < m_vec.size(); ++i) { - m_vec[i].poly_decompose(v1.m_vec[i], v0.m_vec[i], mode); - } - return std::make_tuple(v1, v0); - } - - /************************************************* - * Name: reduce - * - * Description: Reduce coefficients of polynomials in vector - * to representatives in [-6283009,6283007]. - **************************************************/ - void reduce() { - for(auto& i : m_vec) { - i.poly_reduce(); - } - } - - /************************************************* - * Name: invntt_tomont - * - * Description: Inverse NTT and multiplication by 2^{32} of polynomials - * in vector. Input coefficients need to be less - * than 2*Q. - **************************************************/ - void invntt_tomont() { - for(auto& i : m_vec) { - i.poly_invntt_tomont(); - } - } - - /************************************************* - * Name: add_polyvec - * - * Description: Add vectors of polynomials . - * No modular reduction is performed. - * - * Arguments: - const PolynomialVector *v: pointer to second summand - * - const PolynomialVector *u: pointer to first summand - **************************************************/ - void add_polyvec(const PolynomialVector& v) { - BOTAN_ASSERT((m_vec.size() == v.m_vec.size()), "possible buffer overflow! Wrong PolynomialVector sizes."); - for(size_t i = 0; i < m_vec.size(); ++i) { - m_vec[i] += v.m_vec[i]; - } - } - - /************************************************* - * Name: cadd_q - * - * Description: For all coefficients of polynomials in vector - * add Q if coefficient is negative. - **************************************************/ - void cadd_q() { - for(auto& i : m_vec) { - i.cadd_q(); - } - } - - void polyvecl_uniform_gamma1(const secure_vector& seed, - uint16_t nonce, - const DilithiumModeConstants& mode) { - BOTAN_ASSERT_NOMSG(m_vec.size() <= std::numeric_limits::max()); - for(uint16_t i = 0; i < static_cast(this->m_vec.size()); ++i) { - m_vec[i].poly_uniform_gamma1(seed, mode.l() * nonce + i, mode); - } - } - - void polyvec_pointwise_poly_montgomery(PolynomialVector& r, const Polynomial& a) { - for(size_t i = 0; i < m_vec.size(); ++i) { - m_vec[i].poly_pointwise_montgomery(r.m_vec[i], a); - } - } - - /************************************************* - * Name: polyvecl_chknorm - * - * Description: Check infinity norm of polynomials in vector of length L. - * Assumes input polyvecl to be reduced by polyvecl_reduce(). - * - * Arguments: - size_t B: norm bound - * - * Returns false if norm of all polynomials is strictly smaller than B <= (Q-1)/8 - * and true otherwise. - **************************************************/ - bool polyvec_chknorm(size_t bound) { - for(auto& i : m_vec) { - if(Polynomial::poly_chknorm(i, bound)) { - return true; - } - } - return false; - } - - /************************************************* - * Name: polyvec_shiftl - * - * Description: Multiply vector of polynomials by 2^D without modular - * reduction. Assumes input coefficients to be less than 2^{31-D}. - **************************************************/ - void polyvec_shiftl() { - for(auto& i : m_vec) { - i.poly_shiftl(); - } - } - - /************************************************* - * Name: polyvec_use_hint - * - * Description: Use hint vector to correct the high bits of input vector. - * - * Arguments: - PolynomialVector& w: reference to output vector of polynomials with - * corrected high bits - * - const PolynomialVector& u: reference to input vector - * - const PolynomialVector& h: reference to input hint vector - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - **************************************************/ - void polyvec_use_hint(PolynomialVector& w, const PolynomialVector& h, const DilithiumModeConstants& mode) { - for(size_t i = 0; i < w.m_vec.size(); ++i) { - Polynomial::poly_use_hint(w.m_vec[i], m_vec[i], h.m_vec[i], mode); - } - } - - secure_vector polyvec_pack_eta(const DilithiumModeConstants& mode) const { - secure_vector packed_eta(mode.polyeta_packedbytes() * m_vec.size()); - for(size_t i = 0; i < m_vec.size(); ++i) { - m_vec[i].polyeta_pack(packed_eta.data() + mode.polyeta_packedbytes() * i, mode); - } - return packed_eta; - } - - static PolynomialVector unpack_eta(std::span buffer, - size_t size, - const DilithiumModeConstants& mode) { - BOTAN_ARG_CHECK(buffer.size() == mode.polyeta_packedbytes() * size, "Invalid buffer size"); - - PolynomialVector pv(size); - for(size_t i = 0; i < pv.m_vec.size(); ++i) { - pv.m_vec[i] = Polynomial::polyeta_unpack( - buffer.subspan(i * mode.polyeta_packedbytes(), mode.polyeta_packedbytes()), mode); - } - return pv; - } - - secure_vector polyvec_pack_t0() const { - secure_vector packed_t0(m_vec.size() * DilithiumModeConstants::POLYT0_PACKEDBYTES); - for(size_t i = 0; i < m_vec.size(); ++i) { - m_vec[i].polyt0_pack(packed_t0.data() + i * DilithiumModeConstants::POLYT0_PACKEDBYTES); - } - return packed_t0; - } - - static PolynomialVector unpack_t0(std::span buffer, const DilithiumModeConstants& mode) { - BOTAN_ARG_CHECK(static_cast(buffer.size()) == DilithiumModeConstants::POLYT0_PACKEDBYTES * mode.k(), - "Invalid buffer size"); - - PolynomialVector t0(mode.k()); - for(size_t i = 0; i < t0.m_vec.size(); ++i) { - t0.m_vec[i] = Polynomial::polyt0_unpack(buffer.subspan(i * DilithiumModeConstants::POLYT0_PACKEDBYTES, - DilithiumModeConstants::POLYT0_PACKEDBYTES)); - } - return t0; - } - - std::vector polyvec_pack_t1() const { - std::vector packed_t1(m_vec.size() * DilithiumModeConstants::POLYT1_PACKEDBYTES); - for(size_t i = 0; i < m_vec.size(); ++i) { - m_vec[i].polyt1_pack(packed_t1.data() + i * DilithiumModeConstants::POLYT1_PACKEDBYTES); - } - return packed_t1; - } - - static PolynomialVector unpack_t1(std::span packed_t1, const DilithiumModeConstants& mode) { - BOTAN_ARG_CHECK( - static_cast(packed_t1.size()) == DilithiumModeConstants::POLYT1_PACKEDBYTES * mode.k(), - "Invalid buffer size"); - - PolynomialVector t1(mode.k()); - for(size_t i = 0; i < t1.m_vec.size(); ++i) { - Polynomial::polyt1_unpack(t1.m_vec[i], packed_t1.data() + i * DilithiumModeConstants::POLYT1_PACKEDBYTES); - } - return t1; - } - - std::vector polyvec_pack_w1(const DilithiumModeConstants& mode) { - std::vector packed_w1(mode.polyw1_packedbytes() * m_vec.size()); - for(size_t i = 0; i < m_vec.size(); ++i) { - m_vec[i].polyw1_pack(packed_w1.data() + i * mode.polyw1_packedbytes(), mode); - } - return packed_w1; - } - - static PolynomialVector polyvec_unpack_z(const uint8_t* packed_z, const DilithiumModeConstants& mode) { - PolynomialVector z(mode.l()); - for(size_t i = 0; i < z.m_vec.size(); ++i) { - Polynomial::polyz_unpack(z.m_vec[i], packed_z + i * mode.polyz_packedbytes(), mode); - } - return z; - } - - /************************************************* - * Name: generate_polyvec_matrix_pointwise_montgomery - * - * Description: Generates a PolynomialVector based on a matrix using pointwise montgomery acc - * - * Arguments: - const std::vector& rho[]: byte array containing seed rho - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - * Returns a PolynomialVector - **************************************************/ - static PolynomialVector generate_polyvec_matrix_pointwise_montgomery(const std::vector& mat, - const PolynomialVector& v, - const DilithiumModeConstants& mode) { - PolynomialVector t(mode.k()); - for(size_t i = 0; i < mode.k(); ++i) { - PolynomialVector::polyvec_pointwise_acc_montgomery(t.m_vec[i], mat[i], v); - } - return t; - } -}; - -class PolynomialMatrix { - private: - // Matrix of length k holding a polynomialVector of size l, which has N coeffs - std::vector m_mat; - - explicit PolynomialMatrix(const DilithiumModeConstants& mode) : m_mat(mode.k(), PolynomialVector(mode.l())) {} - - public: - PolynomialMatrix() = delete; - - /************************************************* - * Name: generate_matrix - * - * Description: Implementation of generate_matrix. Generates matrix A with uniformly - * random coefficients a_{i,j} by performing rejection - * sampling on the output stream of SHAKE128(rho|j|i) - * or AES256CTR(rho,j|i). - * - * Arguments: - const std::vector& rho[]: byte array containing seed rho - * - const DilithiumModeConstants& mode: reference to dilihtium mode values - * Returns the output matrix mat[k] - **************************************************/ - static PolynomialMatrix generate_matrix(const std::vector& rho, const DilithiumModeConstants& mode) { - BOTAN_ASSERT(rho.size() >= DilithiumModeConstants::SEEDBYTES, "wrong byte length for rho/seed"); - - PolynomialMatrix matrix(mode); - for(uint16_t i = 0; i < mode.k(); ++i) { - for(uint16_t j = 0; j < mode.l(); ++j) { - matrix.m_mat[i].m_vec[j] = PolynomialVector::poly_uniform(rho, (i << 8) + j, mode); - } - } - return matrix; - } - - const std::vector& get_matrix() const { return m_mat; } -}; -} // namespace Botan::Dilithium - -#endif diff --git a/src/lib/pubkey/dilithium/dilithium_common/dilithium_symmetric_primitives.cpp b/src/lib/pubkey/dilithium/dilithium_common/dilithium_symmetric_primitives.cpp index 6d7b4aed006..8fafbd7ddc9 100644 --- a/src/lib/pubkey/dilithium/dilithium_common/dilithium_symmetric_primitives.cpp +++ b/src/lib/pubkey/dilithium/dilithium_common/dilithium_symmetric_primitives.cpp @@ -19,97 +19,20 @@ namespace Botan { -std::unique_ptr Dilithium_Symmetric_Primitives::create(DilithiumMode mode) { +std::unique_ptr Dilithium_Symmetric_Primitives::create(const DilithiumConstants& mode) { #if BOTAN_HAS_DILITHIUM if(mode.is_modern()) { - return std::make_unique(); + return std::make_unique(mode.commitment_hash_full_bytes()); } #endif #if BOTAN_HAS_DILITHIUM_AES if(mode.is_aes()) { - return std::make_unique(); + return std::make_unique(mode.commitment_hash_full_bytes()); } #endif throw Not_Implemented("requested Dilithium mode is not enabled in this build"); } -DilithiumModeConstants::DilithiumModeConstants(DilithiumMode mode) : - m_mode(mode), m_symmetric_primitives(Dilithium_Symmetric_Primitives::create(mode)) { - if(mode.is_modern()) { - m_stream128_blockbytes = DilithiumModeConstants::SHAKE128_RATE; - m_stream256_blockbytes = DilithiumModeConstants::SHAKE256_RATE; - } else { - m_stream128_blockbytes = AES256CTR_BLOCKBYTES; - m_stream256_blockbytes = AES256CTR_BLOCKBYTES; - } - - switch(m_mode.mode()) { - case Botan::DilithiumMode::Dilithium4x4: - case Botan::DilithiumMode::Dilithium4x4_AES: - m_k = 4; - m_l = 4; - m_eta = DilithiumEta::Eta2; - m_tau = 39; - m_beta = 78; - m_gamma1 = (1 << 17); - m_gamma2 = ((DilithiumModeConstants::Q - 1) / 88); - m_omega = 80; - m_nist_security_strength = 128; - m_polyz_packedbytes = 576; - m_polyw1_packedbytes = 192; - m_polyeta_packedbytes = 96; - m_poly_uniform_eta_nblocks = ((136 + m_stream128_blockbytes - 1) / m_stream128_blockbytes); - break; - case Botan::DilithiumMode::Dilithium6x5: - case Botan::DilithiumMode::Dilithium6x5_AES: - m_k = 6; - m_l = 5; - m_eta = DilithiumEta::Eta4; - m_tau = 49; - m_beta = 196; - m_gamma1 = (1 << 19); - m_gamma2 = ((DilithiumModeConstants::Q - 1) / 32); - m_omega = 55; - m_nist_security_strength = 192; - m_polyz_packedbytes = 640; - m_polyw1_packedbytes = 128; - m_polyeta_packedbytes = 128; - m_poly_uniform_eta_nblocks = ((227 + m_stream128_blockbytes - 1) / m_stream128_blockbytes); - break; - case Botan::DilithiumMode::Dilithium8x7: - case Botan::DilithiumMode::Dilithium8x7_AES: - m_k = 8; - m_l = 7; - m_eta = DilithiumEta::Eta2; - m_tau = 60; - m_beta = 120; - m_gamma1 = (1 << 19); - m_gamma2 = ((DilithiumModeConstants::Q - 1) / 32); - m_omega = 75; - m_nist_security_strength = 256; - m_polyz_packedbytes = 640; - m_polyw1_packedbytes = 128; - m_polyeta_packedbytes = 96; - m_poly_uniform_eta_nblocks = ((136 + m_stream128_blockbytes - 1) / m_stream128_blockbytes); - break; - } - - if(m_gamma1 == (1 << 17)) { - m_poly_uniform_gamma1_nblocks = (576 + m_stream256_blockbytes - 1) / m_stream256_blockbytes; - } else { - BOTAN_ASSERT_NOMSG(m_gamma1 == (1 << 19)); - m_poly_uniform_gamma1_nblocks = (640 + m_stream256_blockbytes - 1) / m_stream256_blockbytes; - } - - // For all modes the same calculation - m_polyvech_packedbytes = m_omega + m_k; - m_poly_uniform_nblocks = ((768 + m_stream128_blockbytes - 1) / m_stream128_blockbytes); - m_public_key_bytes = DilithiumModeConstants::SEEDBYTES + m_k * DilithiumModeConstants::POLYT1_PACKEDBYTES; - m_crypto_bytes = DilithiumModeConstants::SEEDBYTES + m_l * m_polyz_packedbytes + m_polyvech_packedbytes; - m_private_key_bytes = (3 * DilithiumModeConstants::SEEDBYTES + m_l * m_polyeta_packedbytes + - m_k * m_polyeta_packedbytes + m_k * DilithiumModeConstants::POLYT0_PACKEDBYTES); -} - } // namespace Botan diff --git a/src/lib/pubkey/dilithium/dilithium_common/dilithium_symmetric_primitives.h b/src/lib/pubkey/dilithium/dilithium_common/dilithium_symmetric_primitives.h index c752843aa27..0ebb2d881cc 100644 --- a/src/lib/pubkey/dilithium/dilithium_common/dilithium_symmetric_primitives.h +++ b/src/lib/pubkey/dilithium/dilithium_common/dilithium_symmetric_primitives.h @@ -12,198 +12,140 @@ #include -#include -#include - -#include -#include -#include +#include +#include +#include +#include namespace Botan { /** -* Adapter class that uses polymorphy to distinguish -* Dilithium "common" from Dilithium "AES" modes. -*/ -class Dilithium_Symmetric_Primitives { - public: - enum class XofType { k128, k256 }; - + * Wrapper type for the H() function calculating the message representative for + * the Dilithium signature scheme. This wrapper may be used multiple times. + * + * Namely: mu = H(tr || M) + */ +class DilithiumMessageHash { public: - static std::unique_ptr create(DilithiumMode mode); - - virtual ~Dilithium_Symmetric_Primitives() = default; + DilithiumMessageHash(DilithiumHashedPublicKey tr) : m_tr(std::move(tr)) { clear(); } - // H is same for all modes - secure_vector H(std::span seed, size_t out_len) const { - return SHAKE_256(out_len * 8).process(seed.data(), seed.size()); + std::string name() const { + return Botan::fmt("{}({})", m_shake.name(), DilithiumConstants::MESSAGE_HASH_BYTES * 8); } - // CRH is same for all modes - secure_vector CRH(std::span in, size_t out_len) const { - return SHAKE_256(out_len * 8).process(in.data(), in.size()); + void update(std::span data) { m_shake.update(data); } + + DilithiumMessageRepresentative final() { + scoped_cleanup clean([this]() { clear(); }); + return m_shake.output(DilithiumConstants::MESSAGE_HASH_BYTES); } - // ExpandMatrix always uses the 256 version of the XOF - secure_vector ExpandMask(std::span seed, uint16_t nonce, size_t out_len) const { - return XOF(XofType::k256, seed, nonce)->output(out_len); + private: + void clear() { + m_shake.clear(); + m_shake.update(m_tr); } - // Mode dependent function - virtual std::unique_ptr XOF(XofType type, std::span seed, uint16_t nonce) const = 0; + private: + DilithiumHashedPublicKey m_tr; + SHAKE_256_XOF m_shake; }; -enum DilithiumEta : uint32_t { Eta2 = 2, Eta4 = 4 }; - -// Constants and mode dependent values -class DilithiumModeConstants { +/** +* Adapter class that uses polymorphy to distinguish +* Dilithium "common" from Dilithium "AES" modes. +*/ +class Dilithium_Symmetric_Primitives { public: - static constexpr int32_t SEEDBYTES = 32; - static constexpr int32_t CRHBYTES = 64; - static constexpr int32_t N = 256; - static constexpr int32_t Q = 8380417; - static constexpr int32_t D = 13; - static constexpr int32_t ROOT_OF_UNITY = 1753; - static constexpr int32_t POLYT1_PACKEDBYTES = 320; - static constexpr int32_t POLYT0_PACKEDBYTES = 416; - static constexpr int32_t SHAKE128_RATE = 168; - static constexpr int32_t SHAKE256_RATE = 136; - static constexpr int32_t SHA3_256_RATE = 136; - static constexpr int32_t SHA3_512_RATE = 72; - static constexpr int32_t AES256CTR_BLOCKBYTES = 64; - static constexpr int32_t QINV = 58728449; - static constexpr int32_t ZETAS[DilithiumModeConstants::N] = { - 0, 25847, -2608894, -518909, 237124, -777960, -876248, 466468, 1826347, 2353451, -359251, - -2091905, 3119733, -2884855, 3111497, 2680103, 2725464, 1024112, -1079900, 3585928, -549488, -1119584, - 2619752, -2108549, -2118186, -3859737, -1399561, -3277672, 1757237, -19422, 4010497, 280005, 2706023, - 95776, 3077325, 3530437, -1661693, -3592148, -2537516, 3915439, -3861115, -3043716, 3574422, -2867647, - 3539968, -300467, 2348700, -539299, -1699267, -1643818, 3505694, -3821735, 3507263, -2140649, -1600420, - 3699596, 811944, 531354, 954230, 3881043, 3900724, -2556880, 2071892, -2797779, -3930395, -1528703, - -3677745, -3041255, -1452451, 3475950, 2176455, -1585221, -1257611, 1939314, -4083598, -1000202, -3190144, - -3157330, -3632928, 126922, 3412210, -983419, 2147896, 2715295, -2967645, -3693493, -411027, -2477047, - -671102, -1228525, -22981, -1308169, -381987, 1349076, 1852771, -1430430, -3343383, 264944, 508951, - 3097992, 44288, -1100098, 904516, 3958618, -3724342, -8578, 1653064, -3249728, 2389356, -210977, - 759969, -1316856, 189548, -3553272, 3159746, -1851402, -2409325, -177440, 1315589, 1341330, 1285669, - -1584928, -812732, -1439742, -3019102, -3881060, -3628969, 3839961, 2091667, 3407706, 2316500, 3817976, - -3342478, 2244091, -2446433, -3562462, 266997, 2434439, -1235728, 3513181, -3520352, -3759364, -1197226, - -3193378, 900702, 1859098, 909542, 819034, 495491, -1613174, -43260, -522500, -655327, -3122442, - 2031748, 3207046, -3556995, -525098, -768622, -3595838, 342297, 286988, -2437823, 4108315, 3437287, - -3342277, 1735879, 203044, 2842341, 2691481, -2590150, 1265009, 4055324, 1247620, 2486353, 1595974, - -3767016, 1250494, 2635921, -3548272, -2994039, 1869119, 1903435, -1050970, -1333058, 1237275, -3318210, - -1430225, -451100, 1312455, 3306115, -1962642, -1279661, 1917081, -2546312, -1374803, 1500165, 777191, - 2235880, 3406031, -542412, -2831860, -1671176, -1846953, -2584293, -3724270, 594136, -3776993, -2013608, - 2432395, 2454455, -164721, 1957272, 3369112, 185531, -1207385, -3183426, 162844, 1616392, 3014001, - 810149, 1652634, -3694233, -1799107, -3038916, 3523897, 3866901, 269760, 2213111, -975884, 1717735, - 472078, -426683, 1723600, -1803090, 1910376, -1667432, -1104333, -260646, -3833893, -2939036, -2235985, - -420899, -2286327, 183443, -976891, 1612842, -3545687, -554416, 3919660, -48306, -1362209, 3937738, - 1400424, -846154, 1976782}; - static constexpr int32_t kSerializedPolynomialByteLength = DilithiumModeConstants::N / 2 * 3; - - DilithiumModeConstants(DilithiumMode dimension); - - DilithiumModeConstants(const DilithiumModeConstants& other) : DilithiumModeConstants(other.m_mode) {} - - DilithiumModeConstants(DilithiumModeConstants&& other) = default; - DilithiumModeConstants& operator=(const DilithiumModeConstants& other) = delete; - DilithiumModeConstants& operator=(DilithiumModeConstants&& other) = default; - - // Getter - uint8_t k() const { return m_k; } - - uint8_t l() const { return m_l; } - - DilithiumEta eta() const { return m_eta; } - - size_t tau() const { return m_tau; } - - size_t poly_uniform_gamma1_nblocks() const { return m_poly_uniform_gamma1_nblocks; } - - size_t stream256_blockbytes() const { return m_stream256_blockbytes; } - - size_t stream128_blockbytes() const { return m_stream128_blockbytes; } - - size_t polyw1_packedbytes() const { return m_polyw1_packedbytes; } - - size_t omega() const { return m_omega; } - - size_t polyz_packedbytes() const { return m_polyz_packedbytes; } - - size_t gamma2() const { return m_gamma2; } - - size_t gamma1() const { return m_gamma1; } - - size_t beta() const { return m_beta; } - - size_t poly_uniform_eta_nblocks() const { return m_poly_uniform_eta_nblocks; } - - size_t poly_uniform_nblocks() const { return m_poly_uniform_nblocks; } + enum class XofType { k128, k256 }; - size_t polyeta_packedbytes() const { return m_polyeta_packedbytes; } + protected: + Dilithium_Symmetric_Primitives(size_t commitment_hash_length_bytes) : + m_commitment_hash_length_bytes(commitment_hash_length_bytes) {} - size_t public_key_bytes() const { return m_public_key_bytes; } + public: + static std::unique_ptr create(const DilithiumConstants& mode); - size_t crypto_bytes() const { return m_crypto_bytes; } + virtual ~Dilithium_Symmetric_Primitives() = default; + Dilithium_Symmetric_Primitives(const Dilithium_Symmetric_Primitives&) = delete; + Dilithium_Symmetric_Primitives& operator=(const Dilithium_Symmetric_Primitives&) = delete; + Dilithium_Symmetric_Primitives(Dilithium_Symmetric_Primitives&&) = delete; + Dilithium_Symmetric_Primitives& operator=(Dilithium_Symmetric_Primitives&&) = delete; - OID oid() const { return m_mode.object_identifier(); } + DilithiumMessageHash get_message_hash(DilithiumHashedPublicKey tr) const { + return DilithiumMessageHash(std::move(tr)); + } - DilithiumMode mode() const { return m_mode; } + DilithiumHashedPublicKey H(StrongSpan pk) const { + return H_256(DilithiumConstants::PUBLIC_KEY_HASH_BYTES, pk); + } - size_t private_key_bytes() const { return m_private_key_bytes; } + DilithiumSeedRhoPrime H(StrongSpan k, + StrongSpan mu) const { + return H_256(DilithiumConstants::SEED_RHOPRIME_BYTES, k, mu); + } - size_t nist_security_strength() const { return m_nist_security_strength; } + std::tuple H( + StrongSpan seed) const { + scoped_cleanup clean([this]() { m_xof.clear(); }); + m_xof.update(seed); + + // Note: The order of invocations in an initializer list is not + // guaranteed by the C++ standard. Hence, we have to store the + // results in variables to ensure the correct order of execution. + auto rho = m_xof.output(DilithiumConstants::SEED_RHO_BYTES); + auto rhoprime = m_xof.output(DilithiumConstants::SEED_RHOPRIME_BYTES); + auto k = m_xof.output(DilithiumConstants::SEED_SIGNING_KEY_BYTES); + return {std::move(rho), std::move(rhoprime), std::move(k)}; + } - // Wrapper - decltype(auto) H(std::span seed, size_t out_len) const { - return m_symmetric_primitives->H(seed, out_len); + DilithiumCommitmentHash H(StrongSpan mu, + StrongSpan w1) const { + return H_256(m_commitment_hash_length_bytes, mu, w1); } - secure_vector CRH(const std::span in) const { - return m_symmetric_primitives->CRH(in, DilithiumModeConstants::CRHBYTES); + SHAKE_256_XOF& H(StrongSpan seed) const { + m_xof_external.clear(); + m_xof_external.update(seed); + return m_xof_external; } - std::unique_ptr XOF_128(std::span seed, uint16_t nonce) const { - return this->m_symmetric_primitives->XOF(Dilithium_Symmetric_Primitives::XofType::k128, seed, nonce); + // Once Dilithium AES is removed, this could return a SHAKE_256_XOF and + // avoid the virtual method call. + Botan::XOF& H(StrongSpan seed, uint16_t nonce) const { + return XOF(XofType::k128, seed, nonce); } - std::unique_ptr XOF_256(std::span seed, uint16_t nonce) const { - return this->m_symmetric_primitives->XOF(Dilithium_Symmetric_Primitives::XofType::k256, seed, nonce); + // Once Dilithium AES is removed, this could return a SHAKE_128_XOF and + // avoid the virtual method call. + Botan::XOF& H(StrongSpan seed, uint16_t nonce) const { + return XOF(XofType::k256, seed, nonce); } - secure_vector ExpandMask(const secure_vector& seed, uint16_t nonce) const { - return this->m_symmetric_primitives->ExpandMask( - seed, nonce, poly_uniform_gamma1_nblocks() * stream256_blockbytes()); + protected: + /** + * Implemented by the derived classes to create the correct XOF instance. + * This is a customization point to enable support for the AES variant of + * Dilithium. This won't be standardized in the FIPS 204; ML-DSA always + * uses SHAKE. Once we decide to remove the AES variant, this virtual + * method can be removed. + */ + virtual Botan::XOF& XOF(XofType type, std::span seed, uint16_t nonce) const = 0; + + private: + template + OutT H_256(size_t outbytes, InTs&&... ins) const { + scoped_cleanup clean([this]() { m_xof.clear(); }); + (m_xof.update(ins), ...); + return m_xof.output(outbytes); } private: - DilithiumMode m_mode; - - uint16_t m_nist_security_strength; - - // generated matrix dimension is m_k x m_l - uint8_t m_k; - uint8_t m_l; - DilithiumEta m_eta; - int32_t m_tau; - int32_t m_beta; - int32_t m_gamma1; - int32_t m_gamma2; - int32_t m_omega; - int32_t m_stream128_blockbytes; - int32_t m_stream256_blockbytes; - int32_t m_poly_uniform_nblocks; - int32_t m_poly_uniform_eta_nblocks; - int32_t m_poly_uniform_gamma1_nblocks; - int32_t m_polyvech_packedbytes; - int32_t m_polyz_packedbytes; - int32_t m_polyw1_packedbytes; - int32_t m_polyeta_packedbytes; - int32_t m_private_key_bytes; - int32_t m_public_key_bytes; - int32_t m_crypto_bytes; - - // Mode dependent primitives - std::unique_ptr m_symmetric_primitives; + size_t m_commitment_hash_length_bytes; + mutable SHAKE_256_XOF m_xof; + mutable SHAKE_256_XOF m_xof_external; }; + } // namespace Botan #endif diff --git a/src/lib/pubkey/dilithium/dilithium_common/dilithium_types.h b/src/lib/pubkey/dilithium/dilithium_common/dilithium_types.h new file mode 100644 index 00000000000..e42a4bf86c1 --- /dev/null +++ b/src/lib/pubkey/dilithium/dilithium_common/dilithium_types.h @@ -0,0 +1,64 @@ +/* + * Crystals Kyber key encapsulation mechanism + * + * Strong Type definitions used throughout the Dilithium implementation + * + * (C) 2024 Jack Lloyd + * (C) 2024 René Meusel, Rohde & Schwarz Cybersecurity + * + * Botan is released under the Simplified BSD License (see license.txt) + */ + +#ifndef BOTAN_DILITHIUM_TYPES_H_ +#define BOTAN_DILITHIUM_TYPES_H_ + +#include +#include +#include + +namespace Botan { + +using DilithiumPolyNTT = Botan::CRYSTALS::Polynomial; +using DilithiumPolyVecNTT = Botan::CRYSTALS::PolynomialVector; +using DilithiumPolyMatNTT = Botan::CRYSTALS::PolynomialMatrix; + +using DilithiumPoly = Botan::CRYSTALS::Polynomial; +using DilithiumPolyVec = Botan::CRYSTALS::PolynomialVector; + +/// Principal seed used to generate Dilithium key pairs +using DilithiumSeedRandomness = Strong, struct DilithiumSeedRandomness_>; + +/// Public seed to sample the polynomial matrix A from +using DilithiumSeedRho = Strong, struct DilithiumPublicSeed_>; + +/// Private seed to sample the polynomial vectors s1 and s2 from +using DilithiumSeedRhoPrime = Strong, struct DilithiumSeedRhoPrime_>; + +/// Private seed K used during signing +using DilithiumSigningSeedK = Strong, struct DilithiumSeedK_>; + +/// Serialized private key data +using DilithiumSerializedPrivateKey = Strong, struct DilithiumSerializedPrivateKey_>; + +/// Serialized public key data (result of pkEncode(pk)) +using DilithiumSerializedPublicKey = Strong, struct DilithiumSerializedPublicKey_>; + +/// Hash value of the serialized public key data +/// (result of H(BytesToBits(pkEncode(pk)), also referred to as 'tr') +using DilithiumHashedPublicKey = Strong, struct DilithiumHashedPublicKey_>; + +/// Representation of the message to be signed +using DilithiumMessageRepresentative = Strong, struct DilithiumMessageRepresentative_>; + +/// Serialized signature data +using DilithiumSerializedSignature = Strong, struct DilithiumSerializedSignature_>; + +/// Serialized representation of a commitment w1 +using DilithiumSerializedCommitment = Strong, struct DilithiumSerializedCommitment_>; + +/// Hash of the message representative and the signer's commitment +using DilithiumCommitmentHash = Strong, struct DilithiumCommitmentHash_>; + +} // namespace Botan + +#endif diff --git a/src/lib/pubkey/dilithium/dilithium_common/info.txt b/src/lib/pubkey/dilithium/dilithium_common/info.txt index 03039830900..bb4d009fd19 100644 --- a/src/lib/pubkey/dilithium/dilithium_common/info.txt +++ b/src/lib/pubkey/dilithium/dilithium_common/info.txt @@ -13,10 +13,14 @@ dilithium.h -dilithium_polynomials.h +dilithium_algos.h +dilithium_constants.h +dilithium_polynomial.h dilithium_symmetric_primitives.h +dilithium_types.h -shake +pqcrystals +shake_xof