Skip to content

Commit

Permalink
ML-DSA stores the private key as a 32-byte seed
Browse files Browse the repository at this point in the history
  • Loading branch information
reneme committed Sep 17, 2024
1 parent 0ce9409 commit 823455d
Show file tree
Hide file tree
Showing 23 changed files with 482 additions and 3,636 deletions.
96 changes: 19 additions & 77 deletions src/lib/pubkey/dilithium/dilithium_common/dilithium.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,14 @@ bool DilithiumMode::is_available() const {

class Dilithium_Signature_Operation final : public PK_Ops::Signature {
public:
Dilithium_Signature_Operation(std::shared_ptr<Dilithium_PrivateKeyInternal> sk, bool randomized) :
m_priv_key(std::move(sk)),
Dilithium_Signature_Operation(DilithiumInternalKeypair keypair, bool randomized) :
m_keypair(std::move(keypair)),
m_randomized(randomized),
m_h(m_priv_key->mode().symmetric_primitives().get_message_hash(m_priv_key->tr())),
m_s1(ntt(m_priv_key->s1().clone())),
m_s2(ntt(m_priv_key->s2().clone())),
m_t0(ntt(m_priv_key->t0().clone())),
m_A(Dilithium_Algos::expand_A(m_priv_key->rho(), m_priv_key->mode())) {}
m_h(m_keypair.second->mode().symmetric_primitives().get_message_hash(m_keypair.first->tr())),
m_s1(ntt(m_keypair.second->s1().clone())),
m_s2(ntt(m_keypair.second->s2().clone())),
m_t0(ntt(m_keypair.second->t0().clone())),
m_A(Dilithium_Algos::expand_A(m_keypair.first->rho(), m_keypair.second->mode())) {}

void update(std::span<const uint8_t> input) override { m_h->update(input); }

Expand All @@ -139,13 +139,13 @@ class Dilithium_Signature_Operation final : public PK_Ops::Signature {
* application defined and "empty" by default and <= 255 bytes long.
*/
std::vector<uint8_t> sign(RandomNumberGenerator& rng) override {
auto scope = CT::scoped_poison(*m_priv_key);
auto scope = CT::scoped_poison(*m_keypair.second);

const auto mu = m_h->final();
const auto& mode = m_priv_key->mode();
const auto& mode = m_keypair.second->mode();
const auto& sympri = mode.symmetric_primitives();

const auto rhoprime = sympri.H_maybe_randomized(m_priv_key->signing_seed(), mu, maybe(rng));
const auto rhoprime = sympri.H_maybe_randomized(m_keypair.second->signing_seed(), mu, maybe(rng));
CT::poison(rhoprime);

// Note: nonce (as requested by `polyvecl_uniform_gamma1`) is actually just uint16_t
Expand Down Expand Up @@ -207,10 +207,10 @@ class Dilithium_Signature_Operation final : public PK_Ops::Signature {
throw Internal_Error("Dilithium signature loop did not terminate");
}

size_t signature_length() const override { return m_priv_key->mode().signature_bytes(); }
size_t signature_length() const override { return m_keypair.second->mode().signature_bytes(); }

AlgorithmIdentifier algorithm_identifier() const override {
return AlgorithmIdentifier(m_priv_key->mode().mode().object_identifier(),
return AlgorithmIdentifier(m_keypair.second->mode().mode().object_identifier(),
AlgorithmIdentifier::USE_EMPTY_PARAM);
}

Expand All @@ -226,7 +226,7 @@ class Dilithium_Signature_Operation final : public PK_Ops::Signature {
}

private:
std::shared_ptr<Dilithium_PrivateKeyInternal> m_priv_key;
DilithiumInternalKeypair m_keypair;
bool m_randomized;
std::unique_ptr<DilithiumMessageHash> m_h;

Expand Down Expand Up @@ -372,29 +372,6 @@ std::unique_ptr<PK_Ops::Verification> Dilithium_PublicKey::create_x509_verificat
throw Provider_Not_Found(algo_name(), provider);
}

namespace Dilithium_Algos {

namespace {

/**
* NIST FIPS 204, Algorithm 6, lines 5-7 (ML-DSA.KeyGen_internal)
*/
std::pair<DilithiumPolyVec, DilithiumPolyVec> compute_t1_and_t0(const DilithiumPolyMatNTT& A,
const DilithiumPolyVec& s1,
const DilithiumPolyVec& s2) {
auto t_hat = A * ntt(s1.clone());
t_hat.reduce();
auto t = inverse_ntt(std::move(t_hat));
t += s2;
t.conditional_add_q();

return Dilithium_Algos::power2round(t);
}

} // namespace

} // namespace Dilithium_Algos

/**
* NIST FIPS 204, Algorithm 1 (ML-DSA.KeyGen), and 6 (ML-DSA.KeyGen_internal)
*
Expand All @@ -408,60 +385,25 @@ std::pair<DilithiumPolyVec, DilithiumPolyVec> compute_t1_and_t0(const DilithiumP
Dilithium_PrivateKey::Dilithium_PrivateKey(RandomNumberGenerator& rng, DilithiumMode m) {
DilithiumConstants mode(m);
BOTAN_ARG_CHECK(mode.mode().is_available(), "Dilithium/ML-DSA mode is not available in this build");
const auto& sympriv = mode.symmetric_primitives();

const auto xi = rng.random_vec<DilithiumSeedRandomness>(DilithiumConstants::SEED_RANDOMNESS_BYTES);
CT::poison(xi);

auto [rho, rhoprime, key] = sympriv.H(xi); // TODO: Add two-bytes k and l (for domain separation)
CT::unpoison(rho); // rho is public (seed for the public matrix A)

const auto A = Dilithium_Algos::expand_A(rho, mode);
auto [s1, s2] = Dilithium_Algos::expand_s(rhoprime, mode);
auto [t1, t0] = Dilithium_Algos::compute_t1_and_t0(A, s1, s2);

CT::unpoison_all(t1, key, s1, s2, t0);
m_public = std::make_shared<Dilithium_PublicKeyInternal>(mode, rho, std::move(t1));
m_private = std::make_shared<Dilithium_PrivateKeyInternal>(
std::move(mode), std::move(rho), std::move(key), m_public->tr(), std::move(s1), std::move(s2), std::move(t0));
std::tie(m_public, m_private) = Dilithium_Algos::expand_keypair(
rng.random_vec<DilithiumSeedRandomness>(DilithiumConstants::SEED_RANDOMNESS_BYTES), std::move(mode));
}

Dilithium_PrivateKey::Dilithium_PrivateKey(const AlgorithmIdentifier& alg_id, std::span<const uint8_t> sk) :
Dilithium_PrivateKey(sk, DilithiumMode(alg_id.oid())) {}

Dilithium_PrivateKey::Dilithium_PrivateKey(std::span<const uint8_t> sk, DilithiumMode m) {
auto scope = CT::scoped_poison(sk);

DilithiumConstants mode(m);
BOTAN_ARG_CHECK(mode.mode().is_available(), "Dilithium/ML-DSA mode is not available in this build");
BOTAN_ARG_CHECK(sk.size() == mode.private_key_bytes(), "dilithium private key does not have the correct byte count");
m_private =
Dilithium_PrivateKeyInternal::decode(std::move(mode), StrongSpan<const DilithiumSerializedPrivateKey>(sk));

// Currently, Botan's Private_Key class inherits from Public_Key, forcing us
// to derive the public key from the private key here.

// rho is public (used in rejection sampling of matrix A)
CT::unpoison(m_private->rho());

const auto A = Dilithium_Algos::expand_A(m_private->rho(), m_private->mode());
auto [t1, _] = Dilithium_Algos::compute_t1_and_t0(A, m_private->s1(), m_private->s2());
CT::unpoison(t1);

m_public = std::make_shared<Dilithium_PublicKeyInternal>(m_private->mode(), m_private->rho(), std::move(t1));
CT::unpoison(*m_private);

if(m_public->tr() != m_private->tr()) {
throw Decoding_Error("Calculated dilithium public key hash does not match the one stored in the private key");
}
auto& codec = mode.keypair_codec();
std::tie(m_public, m_private) = codec.decode_keypair(sk, std::move(mode));
}

secure_vector<uint8_t> Dilithium_PrivateKey::raw_private_key_bits() const {
return this->private_key_bits();
}

secure_vector<uint8_t> Dilithium_PrivateKey::private_key_bits() const {
return std::move(m_private->raw_sk().get());
return m_private->mode().keypair_codec().encode_keypair({m_public, m_private});
}

std::unique_ptr<PK_Ops::Signature> Dilithium_PrivateKey::create_signature_op(RandomNumberGenerator& rng,
Expand All @@ -476,7 +418,7 @@ std::unique_ptr<PK_Ops::Signature> Dilithium_PrivateKey::create_signature_op(Ran
// We might even drop support for the deterministic variant.
const bool randomized = (params == "Randomized");
if(provider.empty() || provider == "base") {
return std::make_unique<Dilithium_Signature_Operation>(m_private, randomized);
return std::make_unique<Dilithium_Signature_Operation>(DilithiumInternalKeypair{m_public, m_private}, randomized);
}
throw Provider_Not_Found(algo_name(), provider);
}
Expand Down
Loading

0 comments on commit 823455d

Please sign in to comment.